In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import scanpy as sc
import anndata as ad
import pooch

from hv_cancer_modules import DotPlot, UMAPPlot
import holoviews as hv 
hv.extension('bokeh')

In [None]:
sc.settings.set_figure_params(dpi=50, facecolor="white")

The data used in this basic preprocessing and clustering tutorial was collected from bone marrow mononuclear cells of healthy human donors and was part of the [Open Problems NeurIPS 2021 benchmarking dataset](https://openproblems.bio/competitions/neurips_2021/) {cite}`Luecken2021`. The samples used in this tutorial were measured using the 10X Multiome Gene Expression and Chromatin Accessibility kit.


We read the count matrices into an [AnnData](https://anndata.readthedocs.io/en/latest/tutorials/notebooks/getting-started.html) object, which holds many slots for annotations and different representations of the data.

In [None]:
EXAMPLE_DATA = pooch.create(
    path=pooch.os_cache("scverse_tutorials"),
    base_url="doi:10.6084/m9.figshare.22716739.v1/",
)
EXAMPLE_DATA.load_registry_from_doi()

In [None]:
%%time

samples = {
    "s1d1": "s1d1_filtered_feature_bc_matrix.h5",
    "s1d3": "s1d3_filtered_feature_bc_matrix.h5",
}
adatas = {}

for sample_id, filename in samples.items():
    path = EXAMPLE_DATA.fetch(filename)
    sample_adata = sc.read_10x_h5(path)
    sample_adata.var_names_make_unique()
    adatas[sample_id] = sample_adata

adata = ad.concat(adatas, label="sample")
adata.obs_names_make_unique()
print(adata.obs["sample"].value_counts())
adata

The data contains ~8,000 cells per sample and 36,601 measured genes. We'll now investigate these with a basic preprocessing and clustering workflow.

## Quality Control

The scanpy function {func}`~scanpy.pp.calculate_qc_metrics` calculates common quality control (QC) metrics, which are largely based on `calculateQCMetrics` from scater {cite}`McCarthy2017`. One can pass specific gene populations to {func}`~scanpy.pp.calculate_qc_metrics` to calculate the proportion of counts falling in each population. Mitochondrial, ribosomal and hemoglobin genes are identified by the distinct name prefixes used below.

In [None]:
# mitochondrial genes, "MT-" for human, "Mt-" for mouse
adata.var["mt"] = adata.var_names.str.startswith("MT-")
# ribosomal genes
adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
# hemoglobin genes
adata.var["hb"] = adata.var_names.str.contains("^HB[^(P)]")

In [None]:
sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "ribo", "hb"], inplace=True, log1p=True
)

One can now inspect violin plots of some of the computed QC metrics:

* the number of genes expressed in the count matrix
* the total counts per cell
* the percentage of counts in mitochondrial genes
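
The three metrics above can be reproduced by hand. The sketch below uses a made-up count matrix and made-up gene names (assumptions for illustration only) to show how `n_genes_by_counts`, `total_counts` and `pct_counts_mt` are defined; it is not scanpy's implementation of {func}`~scanpy.pp.calculate_qc_metrics`.

```python
import numpy as np

# Toy count matrix (hypothetical data): 4 cells (rows) x 5 genes (columns)
counts = np.array([
    [0, 3, 1, 0, 6],
    [2, 0, 0, 0, 8],
    [1, 1, 1, 1, 0],
    [0, 0, 5, 0, 5],
])
# Hypothetical gene names; the first two carry the human mitochondrial prefix
gene_names = np.array(["MT-CO1", "MT-ND1", "CD14", "LYZ", "HBB"])
is_mt = np.char.startswith(gene_names, "MT-")

# Number of genes with at least one count in each cell
n_genes_by_counts = (counts > 0).sum(axis=1)
# Total counts (library size) per cell
total_counts = counts.sum(axis=1)
# Percentage of each cell's counts that fall in mitochondrial genes
pct_counts_mt = 100 * counts[:, is_mt].sum(axis=1) / total_counts
```

For example, the first toy cell has 3 of its 10 counts in `MT-` genes, so its `pct_counts_mt` is 30.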

In [None]:
adata

In [None]:
sc.pl.violin(
    adata,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt"],
    jitter=0.4,
    multi_panel=True,
)

With HoloViz:

In [None]:
violins = []
for i in ["obs.n_genes_by_counts", "obs.total_counts", "obs.pct_counts_mt"]:
    violins.append(
        hv.Violin(adata, vdims=i).opts(
            ylabel='Value',
            title=i.split('.')[-1],
            show_grid=True,
            ylim=(0,None),
        )
    )
hv.Layout(violins).opts(axiswise=True)

Additionally, it is useful to consider QC metrics jointly by inspecting a scatter plot colored by `pct_counts_mt`. 

In [None]:
sc.pl.scatter(adata, "total_counts", "n_genes_by_counts", color="pct_counts_mt")

With HoloViz:

In [None]:
scatter = (hv.Scatter(adata, "obs.total_counts", ["obs.n_genes_by_counts", "obs.pct_counts_mt"])
    .opts(cmap="Viridis",
          color="obs.pct_counts_mt",
          colorbar=True,
          width=400,
          tools=['hover'],
          show_grid=True,
          title='pct_counts_mt',
))
scatter

Based on the QC metric plots, one could now remove cells with a high percentage of mitochondrial counts or with too many total counts by setting manual or automatic thresholds. However, what appears to be poor QC can sometimes be driven by real biology, so we suggest starting with a very permissive filtering strategy and revisiting it at a later point. We therefore only filter out cells with fewer than 100 genes expressed and genes that are detected in fewer than 3 cells.
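
Both permissive filters are simple boolean masks over the count matrix. The numpy sketch below applies them to random toy data, with a smaller gene threshold than the tutorial's (an assumption to suit the toy size); like {func}`~scanpy.pp.filter_cells` and {func}`~scanpy.pp.filter_genes`, it keeps cells and genes that meet the threshold, and filters genes after cells, in the same order as the cell that follows.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy count matrix (hypothetical data): 200 cells x 50 genes
counts = rng.poisson(0.3, size=(200, 50))

min_genes, min_cells = 5, 3  # toy thresholds; the tutorial uses 100 and 3

# Cell filter: keep cells with at least `min_genes` detected genes
keep_cells = (counts > 0).sum(axis=1) >= min_genes
counts = counts[keep_cells]

# Gene filter (applied to the remaining cells): keep genes detected in at least `min_cells` cells
keep_genes = (counts > 0).sum(axis=0) >= min_cells
counts = counts[:, keep_genes]
```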

Additionally, for datasets with multiple batches, quality control should be performed for each sample individually, as QC thresholds can vary substantially between batches.
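
One way to set automatic thresholds per sample is to flag cells that lie many median absolute deviations (MADs) from their own sample's median. This MAD rule is an assumption swapped in for illustration, not a step of this tutorial; the helper `mad_outlier`, the cutoff of 5 MADs and the toy values are all hypothetical.

```python
import numpy as np
import pandas as pd

def mad_outlier(values: pd.Series, nmads: float = 5.0) -> pd.Series:
    """Hypothetical helper: flag values more than `nmads` MADs away from the median."""
    median = values.median()
    mad = (values - median).abs().median()
    return (values - median).abs() > nmads * mad

# Toy per-cell QC table (hypothetical values) for two samples on different scales
obs = pd.DataFrame({
    "sample": ["s1d1"] * 6 + ["s1d3"] * 6,
    "log1p_total_counts": [8.0, 8.1, 7.9, 8.2, 8.0, 4.0,
                           9.0, 9.1, 8.9, 9.2, 9.0, 9.05],
})
# Compute the median and MAD within each sample rather than on the pooled data
obs["outlier"] = obs.groupby("sample")["log1p_total_counts"].transform(mad_outlier)
```

Here only the last `s1d1` cell is flagged; each sample's threshold adapts to that sample's own distribution.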

In [None]:
sc.pp.filter_cells(adata, min_genes=100)
sc.pp.filter_genes(adata, min_cells=3)

### Doublet detection

As a next step, we run a doublet detection algorithm. Identifying doublets is crucial as they can lead to misclassifications or distortions in downstream analysis steps. Scanpy contains the doublet detection method Scrublet {cite}`Wolock2019`. Scrublet predicts cell doublets using a nearest-neighbor classifier of observed transcriptomes and simulated doublets. {func}`scanpy.pp.scrublet` adds `doublet_score` and `predicted_doublet` to `.obs`. One can now either filter directly on `predicted_doublet` or use the `doublet_score` later during clustering to filter clusters with high doublet scores. 
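
To make the idea behind Scrublet concrete, here is a heavily simplified toy version: simulate doublets by adding the profiles of two random cells, then score each observed cell by the fraction of its k nearest neighbors that are simulated doublets. The function name and parameters are hypothetical; the real method works on normalized, PCA-reduced data and adjusts the neighbor fractions for the expected doublet rate, which this sketch omits.

```python
import numpy as np

def toy_doublet_scores(X, n_sim=None, k=10, seed=0):
    """Hypothetical toy score: fraction of each cell's k nearest neighbors that are
    simulated doublets (raw Euclidean space, no normalization or PCA)."""
    rng = np.random.default_rng(seed)
    n_obs = X.shape[0]
    n_sim = n_sim or n_obs
    # Simulate doublets by summing the profiles of two randomly chosen cells
    i, j = rng.integers(0, n_obs, size=(2, n_sim))
    sims = X[i] + X[j]
    combined = np.vstack([X, sims])
    is_sim = np.r_[np.zeros(n_obs, bool), np.ones(n_sim, bool)]
    # Distances from every observed cell to all observed and simulated profiles
    d = np.linalg.norm(X[:, None, :] - combined[None, :, :], axis=2)
    d[np.arange(n_obs), np.arange(n_obs)] = np.inf  # a cell is not its own neighbor
    nn = np.argsort(d, axis=1)[:, :k]
    return is_sim[nn].mean(axis=1)
```

Cells whose neighborhoods are dominated by simulated doublets look like mixtures of two cell types and receive high scores.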

In [None]:
%%time

sc.pp.scrublet(adata, batch_key="sample")

We can remove doublets by either filtering out the cells called as doublets, or waiting until we've done a clustering pass and filtering out any clusters with high doublet scores.

:::{seealso}
Alternative methods for doublet detection within the scverse ecosystem are [DoubletDetection](https://github.com/JonathanShor/DoubletDetection) and [SOLO](https://docs.scvi-tools.org/en/stable/user_guide/models/solo.html). You can read more about these in the [Doublet Detection chapter](https://www.sc-best-practices.org/preprocessing_visualization/quality_control.html#doublet-detection) of Single Cell Best Practices.
:::

## Normalization

The next preprocessing step is normalization. A common approach is count depth scaling with subsequent log plus one (log1p) transformation. Count depth scaling normalizes the data to a “size factor” such as the median count depth in the dataset, ten thousand (CP10k) or one million (CPM, counts per million). The size factor for count depth scaling can be controlled via `target_sum` in `pp.normalize_total`. We are applying median count depth normalization with log1p transformation (AKA log1PF).
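
Median count depth scaling followed by log1p can be written in a few lines of numpy. The sketch below uses a toy matrix; it mirrors what `pp.normalize_total` with the default `target_sum=None` (scale each cell to the median total count) followed by `pp.log1p` does, without claiming to match scanpy's implementation details.

```python
import numpy as np

# Toy count matrix (hypothetical data): 3 cells x 3 genes with totals 4, 8 and 12
counts = np.array([
    [1, 0, 3],
    [2, 2, 4],
    [0, 6, 6],
], dtype=float)

totals = counts.sum(axis=1, keepdims=True)
target_sum = np.median(totals)               # median count depth across cells (8 here)
normalized = counts / totals * target_sum    # every cell now sums to target_sum
log_normalized = np.log1p(normalized)        # log(1 + x)
```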

In [None]:
# Saving count data
adata.layers["counts"] = adata.X.copy()

In [None]:
# Normalizing to median total counts
sc.pp.normalize_total(adata)
# Logarithmize the data
sc.pp.log1p(adata)

## Feature selection

As a next step, we want to reduce the dimensionality of the dataset and only include the most informative genes. This step is commonly known as feature selection. The scanpy function `pp.highly_variable_genes` annotates highly variable genes by reproducing the implementations of Seurat {cite}`Satija2015`, Cell Ranger {cite}`Zheng2017`, and Seurat v3 {cite}`Stuart2019` depending on the chosen `flavor`. 
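
The idea behind the default `flavor="seurat"` can be sketched as: compute each gene's mean and dispersion (variance / mean), group genes into bins of similar mean expression, z-score the dispersion within each bin, and keep the top-ranked genes. The function below is a simplified, hypothetical illustration on dense data; it does not reproduce scanpy's exact binning, NaN handling or per-batch merging with `batch_key`.

```python
import numpy as np

def toy_seurat_hvg(log_X, n_top_genes, n_bins=20):
    """Hypothetical simplified Seurat-style HVG ranking.

    log_X: dense cells x genes matrix of log1p-normalized values.
    Returns the column indices of the `n_top_genes` most variable genes.
    """
    X = np.expm1(log_X)                      # undo log1p to get normalized values
    mean = X.mean(axis=0)
    var = X.var(axis=0, ddof=1)
    eps = 1e-12                              # guards against log(0) and division by zero
    dispersion = np.log(np.maximum(var, eps) / np.maximum(mean, eps))
    log_mean = np.log1p(mean)

    # Assign genes to equal-width bins of (log) mean expression
    lo, hi = log_mean.min(), log_mean.max()
    width = (hi - lo) / n_bins if hi > lo else 1.0
    bins = np.minimum(((log_mean - lo) / width).astype(int), n_bins - 1)

    # z-score dispersion within each bin so genes are compared with genes of similar mean
    disp_norm = np.zeros_like(dispersion)
    for b in np.unique(bins):
        in_bin = bins == b
        sd = dispersion[in_bin].std(ddof=1) if in_bin.sum() > 1 else 0.0
        if sd > 0:
            disp_norm[in_bin] = (dispersion[in_bin] - dispersion[in_bin].mean()) / sd
    return np.argsort(-disp_norm)[:n_top_genes]
```

Binning by mean matters because dispersion tends to depend on expression level; without it, the ranking would mostly reflect that trend.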

In [None]:
sc.pp.highly_variable_genes(adata, n_top_genes=2000, batch_key="sample")

In [None]:
sc.pl.highly_variable_genes(adata)

TODO: with HoloViz:

In [None]:
# Show the HVG flag and mean expression side by side (a cell only displays its last expression)
adata.var[['highly_variable', 'means']]

## Dimensionality Reduction
Reduce the dimensionality of the data by running principal component analysis (PCA), which reveals the main axes of variation and denoises the data.
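
PCA can be computed from a singular value decomposition (SVD) of the mean-centered data. The sketch below is a plain numpy version on dense toy data; the function name is hypothetical, and {func}`~scanpy.tl.pca` additionally restricts to highly variable genes when they are annotated and may use other solvers, so its values can differ (for instance in the sign of components).

```python
import numpy as np

def pca_svd(X, n_comps):
    """Hypothetical helper: PCA via SVD of the mean-centered cells x genes matrix."""
    Xc = X - X.mean(axis=0)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = U[:, :n_comps] * S[:n_comps]      # cell coordinates (analogous to .obsm["X_pca"])
    loadings = Vt[:n_comps].T                  # gene loadings (analogous to .varm["PCs"])
    var = S**2 / (X.shape[0] - 1)              # variance captured by each component
    variance_ratio = var[:n_comps] / var.sum() # analogous to .uns["pca"]["variance_ratio"]
    return scores, loadings, variance_ratio
```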

In [None]:
sc.tl.pca(adata)

Let us inspect the contribution of single PCs to the total variance in the data. This tells us how many PCs we should use downstream, for example to compute the neighborhood graph used by the clustering function {func}`~scanpy.tl.leiden`, or as input to {func}`~scanpy.tl.tsne`. In our experience, there does not seem to be a significant downside to overestimating the number of principal components.

In [None]:
sc.pl.pca_variance_ratio(adata, n_pcs=50, log=True)

with HoloViz:

In [None]:
arr = adata.uns['pca']['variance_ratio']

x_positions = np.arange(1, len(arr) + 1)
labels = ['PC{}'.format(i) for i in range(1, len(arr) + 1)]

points = hv.Points((x_positions, arr), ['Ranking', 'Variance Ratio']).opts(color='Variance Ratio', cmap='magma_r', size=6, tools=['hover'])
label_overlay = hv.Labels((x_positions, arr, labels), ['x', 'y'], 'text').opts(
    text_font_size='8pt', text_align='left', yoffset=.003, angle=90, text_color='black')

var_ratio_plot = (points * label_overlay).opts(
    width=600, height=400, xlabel='Ranking', ylabel=None, title='Variance Ratio', show_grid=True)

var_ratio_plot

You can also plot the principal components to see if there are any potentially undesired features (e.g. batch, QC metrics) driving significant variation in this dataset. In this case, there isn't anything too alarming, but it's a good idea to explore this.

In [None]:
sc.pl.pca(
    adata,
    color=["sample", "sample", "pct_counts_mt", "pct_counts_mt"],
    dimensions=[(0, 1), (2, 3), (0, 1), (2, 3)],
    ncols=2,
    size=2,
)

with HoloViz v1:

In [None]:
import holoviews.operation.datashader as hd
import datashader as ds
import colorcet as cc
import panel as pn
import numpy as np
import holoviews as hv


def create_hv_dr_plot(x_data, color_data, x_dim, y_dim, color_var, xaxis_label, yaxis_label, width=300, height=300, datashading=False):
    """
    Helper function to create a single dimensionality reduction scatter plot with Holoviews and optional datashading.

    Parameters:
    - x_data: numpy.ndarray, dimensionality reduction data.
    - color_data: numpy.ndarray, color values (categorical or continuous).
    - x_dim, y_dim: int, indices for components.
    - color_var: str, the name of the coloring variable.
    - xaxis_label, yaxis_label: str, labels for the axes.
    - width, height: int, dimensions of the plot.
    - datashading: bool, whether to apply Datashader.

    Returns:
    - hv.Element: A configured HoloViews scatter plot or datashaded plot.
    """
    cats = ['category', 'categorical', 'bool']
    if color_data.dtype.name in cats:  # Categorical data
        cmap = cc.b_glasbey_category10
        show_legend = True
        colorbar = False
    else:  # Continuous data
        cmap = 'Viridis_r'
        show_legend = False
        colorbar = True

    # Create the base plot
    plot = hv.Points(
        (x_data[:, x_dim], x_data[:, y_dim], color_data),
        [xaxis_label, yaxis_label], color_var
    ).opts(
        color=color_var,
        cmap=cmap,
        size=1,
        alpha=0.5,
        colorbar=colorbar,
        padding=0,
        tools=['hover'],
        show_legend=show_legend,
    )

    if datashading:
        if color_data.dtype.name in cats:
            plot = hd.rasterize(plot, aggregator=ds.by(color_var, ds.count())).opts(cmap=cmap)
            plot = hd.dynspread(plot, threshold=.5)
            # Create a fake legend overlay for the categorical case.
            # Use the data's own category order, which the rasterized colors follow;
            # np.unique would sort labels as strings ('0', '1', '10', '2', ...) and mismatch the colors.
            if hasattr(color_data, 'categories'):  # pandas Categorical
                unique_categories = list(color_data.categories)
            else:  # e.g. boolean arrays
                unique_categories = list(np.unique(color_data))
            color_key = {cat: cmap[i % len(cmap)] for i, cat in enumerate(unique_categories)}
            legend_overlay = hv.NdOverlay({
                str(cat): hv.Points([(0, 0)], label=str(cat)).opts(
                    color=color_key[cat],
                    size=0,
                    legend_position='right',
                    legend_cols = len(unique_categories) // 5,
                )
                for cat in unique_categories
            })
            plot = (plot * legend_overlay).opts(show_legend=True, legend_limit = len(unique_categories)+1) # legend still not displaying when more than a few
        else:
            plot = hd.rasterize(plot, aggregator=ds.mean(color_var)).opts(cmap=cmap, colorbar=colorbar)
            plot = hd.dynspread(plot, threshold=.5)
    return plot.opts(
        title=f"{color_var}",
        tools=['hover'],
        show_legend=show_legend,
        width=width,
        height=height
    )


def hv_dr_plot_layout(adata, color, dimensions, dr_method=('X_pca', 'PCA')):
    """
    Create a static layout of dimensionality reduction scatter plots.

    Parameters:
    - adata: AnnData object containing dimensionality reduction data in `.obsm` and metadata in `.obs`.
    - color: list of str, coloring variables for each plot.
    - dimensions: list of tuples, specifying component indices.
    - dr_method: tuple, (key for the dimensionality reduction data in `.obsm`, label for the method).

    Returns:
    - hv.Layout: A layout of scatter plots.
    """
    dr_key, dr_label = dr_method
    plots = []
    x_data = adata.obsm[dr_key]
    for color_var, (x_dim, y_dim) in zip(color, dimensions):
        color_data = adata.obs[color_var].values
        plot = create_hv_dr_plot(
            x_data, color_data, x_dim, y_dim, color_var,
            f'{dr_label} {x_dim + 1}', f'{dr_label} {y_dim + 1}', width=350, height=300,
            datashading=True,
        )
        plots.append(plot)

    layout = hv.Layout(plots).opts(shared_axes=False, axiswise=True)
    return layout


layout = hv_dr_plot_layout(
    adata=adata,
    color=["sample", "sample", "pct_counts_mt", "pct_counts_mt"],
    dimensions=[(0, 1), (2, 3), (0, 1), (2, 3)],
    dr_method=('X_pca', 'PCA')
)
layout.cols(2)

with HoloViz v2 (UI for axes and coloring):

In [None]:
from panel.io import hold

def create_dr_app(adata, datashade=True):
    dr_options = {key: key.split('_')[1].upper() for key in adata.obsm.keys()}
    initial_dr_key = list(dr_options.keys())[0]  # Use the first available reduction as default
    x_data = adata.obsm[initial_dr_key]
    num_dims = x_data.shape[1]
    dim_options = [f"{dr_options[initial_dr_key]} {i+1}" for i in range(num_dims)]
    color_options = list(adata.obs.columns)  # Use all columns from `.obs`

    dr_select = pn.widgets.Select(name='Dimensionality Reduction', options=list(dr_options.keys()), value=initial_dr_key)
    xaxis = pn.widgets.Select(name='X-axis', options=dim_options, value=dim_options[0])
    yaxis = pn.widgets.Select(name='Y-axis', options=dim_options, value=dim_options[1])
    color = pn.widgets.Select(name='Color', options=color_options, value=color_options[0])
    datashading_switch = pn.widgets.Checkbox(name='Enable Datashader', value=datashade)

    def update_plot(dr_select, xaxis, yaxis, color, datashading):
        x_data = adata.obsm[dr_select]
        x_dim = int(xaxis.split()[-1]) - 1
        y_dim = int(yaxis.split()[-1]) - 1
        color_data = adata.obs[color].values
        dr_label = dr_options[dr_select]

        return create_hv_dr_plot(
            x_data, color_data, x_dim, y_dim, color,
            xaxis, yaxis, width=550, height=500, datashading=datashading
        )

    # Update all axis options (simultaneously with @hold) when dimensionality reduction method changes
    @hold()
    def update_axis_options(event):
        x_data = adata.obsm[event.new]
        num_dims = x_data.shape[1]
        dr_label = dr_options[event.new]
        dim_options = [f"{dr_label} {i+1}" for i in range(num_dims)]
        xaxis.options = dim_options
        yaxis.options = dim_options
        xaxis.value = dim_options[0]
        yaxis.value = dim_options[1]

    dr_select.param.watch(update_axis_options, 'value')

    def enforce_different_axes(event):
        if xaxis.value == yaxis.value:
            available_options = [opt for opt in yaxis.options if opt != xaxis.value]
            if available_options:
                yaxis.value = available_options[0]

    xaxis.param.watch(enforce_different_axes, 'value')
    yaxis.param.watch(enforce_different_axes, 'value')

    plot_pane = pn.bind(update_plot, dr_select=dr_select, xaxis=xaxis, yaxis=yaxis, color=color, datashading=datashading_switch)

    app = pn.Row(pn.WidgetBox(dr_select, xaxis, yaxis, color, datashading_switch), plot_pane)
    return app


interactive_app = create_dr_app(adata, datashade=True)
interactive_app.servable()

## Nearest neighbor graph construction and visualization

Let us compute the neighborhood graph of cells using the PCA representation of the data matrix.
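
At its core, a k-nearest-neighbor graph connects each cell to the k cells closest to it in the PCA space. The numpy sketch below builds a symmetric, unweighted kNN adjacency matrix on toy coordinates; the helper names are hypothetical, and {func}`~scanpy.pp.neighbors` differs in that it uses approximate neighbor search and stores weighted connectivities and distances in `.obsp`.

```python
import numpy as np

def knn_indices(X, k):
    """Hypothetical helper: indices of each row's k nearest neighbors (Euclidean), excluding itself."""
    sq = (X**2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2 * X @ X.T  # squared pairwise distances
    np.fill_diagonal(d2, np.inf)                  # a cell is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]

def knn_adjacency(X, k):
    """Symmetric binary adjacency: i and j are linked if either is among the other's k nearest neighbors."""
    idx = knn_indices(X, k)
    n = X.shape[0]
    A = np.zeros((n, n))
    A[np.repeat(np.arange(n), k), idx.ravel()] = 1.0
    return np.maximum(A, A.T)
```

On well-separated groups of cells, edges stay within groups; this graph structure is what the clustering step below partitions.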

In [None]:
sc.pp.neighbors(adata)

This graph can then be embedded in two dimensions for visualization with UMAP (McInnes et al., 2018):

In [None]:
sc.tl.umap(adata)

We can now visualize the UMAP according to the `sample`. 

In [None]:
sc.pl.umap(
    adata,
    color="sample",
    # Setting a smaller point size to prevent overlap
    size=2,
)

with HoloViz:

In [None]:
layout = hv_dr_plot_layout(
    adata=adata,
    color=["sample", "pct_counts_mt"],
    dimensions=[(0, 1), (0, 1)],
    dr_method=('X_umap', 'UMAP')
)
layout

In [None]:
interactive_app.objects[0][0].value='X_umap'
interactive_app

Even though the data considered in this tutorial includes two different samples, we only observe a minor batch effect and we can continue with clustering and annotation of our data. 

If you observe batch effects in your UMAP, it can be beneficial to integrate across samples and perform batch correction/integration. We recommend checking out [`scanorama`](https://github.com/brianhie/scanorama) and [`scvi-tools`](https://scvi-tools.org) for batch integration.

## Clustering

As with Seurat and many other frameworks, we recommend the Leiden graph-clustering method (community detection based on optimizing modularity) {cite}`Traag2019`. Note that Leiden clustering directly clusters the neighborhood graph of cells, which we already computed in the previous section.
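
Modularity compares the number of edges inside each cluster with the number expected if edges were placed at random while preserving node degrees. The numpy sketch below computes the standard (Newman-Girvan) modularity with a resolution parameter that scales the expected-edges term; this corresponds to the role of the `resolution` argument passed to {func}`~scanpy.tl.leiden` later on. It only scores a given partition on a hypothetical toy graph and does not implement the Leiden search itself.

```python
import numpy as np

def modularity(A, labels, resolution=1.0):
    """Modularity of a partition of an undirected (weighted) graph.

    Q = 1/(2m) * sum_ij [A_ij - resolution * k_i * k_j / (2m)] * delta(c_i, c_j)
    """
    A = np.asarray(A, dtype=float)
    labels = np.asarray(labels)
    k = A.sum(axis=1)                          # node degrees
    two_m = k.sum()                            # 2m: twice the total edge weight
    same = labels[:, None] == labels[None, :]  # delta(c_i, c_j)
    expected = resolution * np.outer(k, k) / two_m
    return ((A - expected) * same).sum() / two_m

# Toy graph: two triangles (nodes 0-2 and 3-5) joined by a single edge 2-3
A = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1.0

split = [0, 0, 0, 1, 1, 1]  # one community per triangle
merged = [0] * 6            # everything in one community

print(modularity(A, split), modularity(A, merged))            # ≈ 0.357 (= 5/14) and ≈ 0.0
print(modularity(A, split, 0.2), modularity(A, merged, 0.2))  # ≈ 0.757 and ≈ 0.8
```

At resolution 1 the two-triangle split scores higher, while at resolution 0.2 merging everything wins: lower resolution favors fewer, larger communities.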

In [None]:
%%time

# Using the igraph implementation and a fixed number of iterations can be significantly faster, especially for larger datasets
sc.tl.leiden(adata, flavor="igraph", n_iterations=2)

In [None]:
sc.pl.umap(adata, color=["leiden"])

with HoloViz:

In [None]:
layout = hv_dr_plot_layout(
    adata=adata,
    color=["sample", "leiden"],
    dimensions=[(0, 1), (0, 1)],
    dr_method=('X_umap', 'UMAP')
)
layout

In [None]:
app_w_leiden = create_dr_app(adata, datashade=True)
app_w_leiden.objects[0][0].value='X_umap'
app_w_leiden.objects[0][3].value='leiden'
app_w_leiden.servable()

## Re-assess quality control and cell filtering 

As indicated before, we will now re-assess our filtering strategy by visualizing different QC metrics using UMAP. 

In [None]:
sc.pl.umap(
    adata,
    color=["leiden", "predicted_doublet", "doublet_score"],
    # increase horizontal space between panels
    wspace=0.5,
    size=3,
)

In [None]:
layout = hv_dr_plot_layout(
    adata=adata,
    color=["leiden", "predicted_doublet", "doublet_score"],
    dimensions=[(0, 1)] *3,
    dr_method=('X_umap', 'UMAP')
)
layout

In [None]:
sc.pl.umap(
    adata,
    color=["leiden", "log1p_total_counts", "pct_counts_mt", "log1p_n_genes_by_counts"],
    wspace=0.5,
    ncols=2,
)

In [None]:
layout = hv_dr_plot_layout(
    adata=adata,
    color=["leiden", "log1p_total_counts", "pct_counts_mt", "log1p_n_genes_by_counts"],
    dimensions=[(0, 1)]*4,
    dr_method=('X_umap', 'UMAP')
)
layout.cols(2)

## Manual cell-type annotation

:::{note}
This section of the tutorial is expanded upon in the [scverse tutorial](https://scverse-tutorials.readthedocs.io/en/latest/notebooks/basic-scrna-tutorial.html#cell-type-annotation), using prior-knowledge resources such as automated assignment and gene enrichment.
:::

Cell type annotation is a laborious and repetitive task, one that typically requires multiple rounds of subclustering and re-annotation. It's difficult to show the entire process in this tutorial, but we aim to show how the tools scanpy provides assist in it.

We have now obtained a set of cells of decent quality and can proceed to annotating them with known cell types. Typically, this is done using genes that are exclusively expressed by a given cell type; these marker genes are used to distinguish the heterogeneous groups of cells in our data. Previous efforts have collected and curated marker genes into resources such as [CellMarker](http://bio-bigdata.hrbmu.edu.cn/CellMarker/), [TF-Marker](http://bio.liclab.net/TF-Marker/), and [PanglaoDB](https://panglaodb.se/). The [cellxgene gene expression tool](https://cellxgene.cziscience.com/gene-expression) can also be useful to see in which cell types a gene is expressed across many existing datasets.

Commonly and classically, cell type annotation uses those marker genes subsequent to the grouping of the cells into clusters. So, let's generate a set of clustering solutions which we can then use to annotate our cell types. Here, we will use the Leiden clustering algorithm which will extract cell communities from our nearest neighbours graph.

In [None]:
for res in [0.02, 0.5, 2.0]:
    sc.tl.leiden(
        adata, key_added=f"leiden_res_{res:4.2f}", resolution=res, flavor="igraph"
    )

Notably, the number of clusters we define is largely arbitrary, and so is the `resolution` parameter we use to control it. Ultimately, the number of clusters should reflect the stable, biologically meaningful groups we can distinguish, which is typically decided by experts in the corresponding field or by using expert-curated prior knowledge in the form of markers.

In [None]:
sc.pl.umap(
    adata,
    color=["leiden_res_0.02", "leiden_res_0.50", "leiden_res_2.00"],
    legend_loc="on data",
)

Though UMAPs should not be over-interpreted, here we can already see that in the highest resolution our data is over-clustered, while the lowest resolution is likely grouping cells which belong to distinct cell identities.

### Marker gene set

Let's define a set of marker genes for the main cell types that we expect to see in this dataset. These were adapted from the [Single Cell Best Practices annotation chapter](https://www.sc-best-practices.org/cellular_structure/annotation.html); for a more detailed overview of best practices in cell type annotation, we refer the reader to that chapter.

In [None]:
marker_genes = {
    "CD14+ Mono": ["FCN1", "CD14"],
    "CD16+ Mono": ["TCF7L2", "FCGR3A", "LYN"],
    # Note: DMXL2 should be negative
    "cDC2": ["CST3", "COTL1", "LYZ", "DMXL2", "CLEC10A", "FCER1A"],
    "Erythroblast": ["MKI67", "HBA1", "HBB"],
    # Note HBM and GYPA are negative markers
    "Proerythroblast": ["CDK6", "SYNGR1", "HBM", "GYPA"],
    "NK": ["GNLY", "NKG7", "CD247", "FCER1G", "TYROBP", "KLRG1", "FCGR3A"],
    "ILC": ["ID2", "PLCG2", "GNLY", "SYNE1"],
    "Naive CD20+ B": ["MS4A1", "IL4R", "IGHD", "FCRL1", "IGHM"],
    # Note IGHD and IGHM are negative markers
    "B cells": [
        "MS4A1",
        "ITGB1",
        "COL4A4",
        "PRDM1",
        "IRF4",
        "PAX5",
        "BCL11A",
        "BLK",
        "IGHD",
        "IGHM",
    ],
    "Plasma cells": ["MZB1", "HSP90B1", "FNDC3B", "PRDM1", "IGKC", "JCHAIN"],
    # Note PAX5 is a negative marker
    "Plasmablast": ["XBP1", "PRDM1", "PAX5"],
    "CD4+ T": ["CD4", "IL7R", "TRBC2"],
    "CD8+ T": ["CD8A", "CD8B", "GZMK", "GZMA", "CCL5", "GZMB", "GZMH"],
    "T naive": ["LEF1", "CCR7", "TCF7"],
    "pDC": ["GZMB", "IL3RA", "COBLL1", "TCF4"],
}

In [None]:
sc.pl.dotplot(adata, marker_genes, groupby="leiden_res_0.02", standard_scale="var")

There are fairly clear patterns of expression for the markers shown here, which we can use to label our coarsest clustering with broad lineages.
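
Each dot in these plots encodes two per-group summaries of a gene: the fraction of cells in the group with expression above zero (dot size) and the mean expression in the group (color). With `standard_scale="var"`, each gene's group means are min-max scaled to [0, 1]. The pandas sketch below reproduces these summaries on hypothetical toy data; the function name is an assumption, and further options of {func}`~scanpy.pl.dotplot` such as `expression_cutoff` and `mean_only_expressed` are not modeled.

```python
import numpy as np
import pandas as pd

def dotplot_stats(expr: pd.DataFrame, groups: pd.Series):
    """Hypothetical helper returning (fraction expressing, scaled mean) per group x gene.

    expr: cells x genes (log-normalized values); groups: cluster label per cell, same index.
    """
    frac = (expr > 0).groupby(groups).mean()  # dot size
    mean = expr.groupby(groups).mean()        # dot color before scaling
    # standard_scale="var": subtract each gene's minimum and divide by its range
    span = mean.max() - mean.min()
    scaled = (mean - mean.min()) / span.replace(0, 1)
    return frac, scaled

# Toy data (hypothetical values)
expr = pd.DataFrame({"CD14": [2.0, 0.0, 1.0, 0.0], "GNLY": [0.0, 0.0, 3.0, 1.0]},
                    index=list("abcd"))
groups = pd.Series(["Mono", "Mono", "NK", "NK"], index=expr.index)
frac, scaled = dotplot_stats(expr, groups)
```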

In [None]:
adata.obs["cell_type_lvl1"] = adata.obs["leiden_res_0.02"].map(
    {
        "0": "Lymphocytes",
        "1": "Monocytes",
        "2": "Erythroid",
        "3": "B Cells",
    }
)

In [None]:
sc.pl.dotplot(adata, marker_genes, groupby="leiden_res_0.50", standard_scale="var")

This resolution seems suitable for distinguishing most of the different cell types in our data. Let's try to annotate them manually using the dotplot above together with the UMAP of our clusters. Ideally, one would also look into each cluster specifically and subcluster it if required.

### Differentially-expressed Genes as Markers

Furthermore, one can calculate marker genes per cluster and then check whether they can be linked to any known biology, such as cell types and/or states. This is typically done using simple statistical tests, such as the Wilcoxon rank-sum test or the t-test, comparing each cluster against all remaining cells.
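
For a single gene, the Wilcoxon rank-sum (Mann-Whitney U) test between one cluster and all other cells ranks the pooled expression values and asks whether the cluster's ranks are larger than expected by chance. The numpy sketch below uses average ranks for ties and the normal approximation without tie correction of the variance; it is an illustration with a hypothetical function name, not scanpy's implementation, although `method="wilcoxon"` in {func}`~scanpy.tl.rank_genes_groups` likewise reports z-scores.

```python
import numpy as np

def rank_sum_z(x_group, x_rest):
    """Hypothetical helper: U statistic and z-score for one gene, group vs rest."""
    x = np.concatenate([x_group, x_rest]).astype(float)
    # 1-based ranks; tied values share the average of their ranks
    order = np.argsort(x, kind="mergesort")
    ranks = np.empty(len(x))
    ranks[order] = np.arange(1, len(x) + 1)
    for v in np.unique(x):
        tie = x == v
        ranks[tie] = ranks[tie].mean()
    n1, n2 = len(x_group), len(x_rest)
    u = ranks[:n1].sum() - n1 * (n1 + 1) / 2     # U statistic of the group
    mu = n1 * n2 / 2                             # mean of U under the null
    sigma = np.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    return u, (u - mu) / sigma
```

A positive z-score means the gene tends to be higher in the cluster than in the rest; ranking genes by this score per cluster gives a marker list.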

In [None]:
# Obtain cluster-specific differentially expressed genes
sc.tl.rank_genes_groups(adata, groupby="leiden_res_0.50", method="wilcoxon")

We can then visualize the top 5 differentially-expressed genes on a dotplot.

In [None]:
sc.pl.rank_genes_groups_dotplot(
    adata, groupby="leiden_res_0.50", standard_scale="var", n_genes=5
)

We can then use these genes to figure out what cell types we're looking at. For example, Cluster 7 is expressing [*NKG7*](https://www.genecards.org/cgi-bin/carddisp.pl?gene=NKG7&keywords=nkg7) and [*GNLY*](https://www.genecards.org/cgi-bin/carddisp.pl?gene=GNLY&keywords=GNLY), suggesting these are [NK cells](https://en.wikipedia.org/wiki/Natural_killer_cell).

To create your own plots, or use a more automated approach, the differentially expressed genes can be extracted in a convenient format with {func}`scanpy.get.rank_genes_groups_df`.

In [None]:
sc.get.rank_genes_groups_df(adata, group="7").head(5)

In [None]:
nk_cluster_genes = sc.get.rank_genes_groups_df(adata, group="7").head(5)["names"]
sc.pl.umap(
    adata,
    color=[*nk_cluster_genes, "leiden_res_0.50"],
    legend_loc="on data",
    frameon=False,
    ncols=3,
)

You may have noticed that the p-values found here are extremely low. This is because the statistical test treats each cell as an independent sample. For a more conservative approach, you may want to consider "pseudo-bulking" your data by sample (*e.g.* `sc.get.aggregate(adata, by=["sample", "cell_type_lvl1"], func="sum", layer="counts")`) and using a more powerful differential expression tool, like [`pydeseq2`](https://pydeseq2.readthedocs.io/).
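
Pseudo-bulking sums the raw counts of all cells that share a sample and a cell type, giving one profile per (sample, cell type) pair so that samples, not cells, become the units of replication. The pandas sketch below illustrates this on hypothetical toy counts and labels; in this notebook the raw counts are stored in `adata.layers["counts"]`.

```python
import pandas as pd

# Toy raw counts (cells x genes) and per-cell annotations (hypothetical values)
counts = pd.DataFrame(
    [[1, 0], [2, 3], [0, 1], [4, 4], [1, 1]],
    columns=["NKG7", "GNLY"],
)
obs = pd.DataFrame({
    "sample": ["s1d1", "s1d1", "s1d1", "s1d3", "s1d3"],
    "cell_type": ["NK", "NK", "B", "NK", "B"],
})
# Sum raw counts over all cells sharing a (sample, cell_type) pair
pseudobulk = counts.groupby([obs["sample"], obs["cell_type"]]).sum()
```

The result has one row per (sample, cell type) pair, which is the kind of input sample-level tools such as `pydeseq2` expect.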

# HoloViz

In [None]:
# adata.write(filename='adata.h5ad')
adata = ad.read_h5ad('adata.h5ad')

In [None]:
from pprint import pprint

In [None]:
adata.var_names

Contents of `adata` at this point:
- X : main data as n_obs × n_vars array = 17041 × 23427
- obs (observation/cell vectors as pandas Series): 'sample', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'n_genes', 'doublet_score', 'predicted_doublet', 'leiden', 'leiden_res_0.02', 'leiden_res_0.50', 'leiden_res_2.00', 'cell_type_lvl1'
- var (variable/gene vectors as pandas Series): 'mt', 'ribo', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
- uns (unstructured meta, config, stats as dicts): 'scrublet', 'log1p', 'hvg', 'pca', 'sample_colors', 'neighbors', 'umap', 'leiden', 'leiden_colors', 'predicted_doublet_colors', 'leiden_res_0.02', 'leiden_res_0.50', 'leiden_res_2.00', 'leiden_res_0.02_colors', 'leiden_res_0.50_colors', 'leiden_res_2.00_colors', 'rank_genes_groups', 'dendrogram_leiden_res_0.50'
- obsm (observation/cell matrices as n_obs × k arrays): 'X_pca', 'X_umap'
- varm (variable/gene matrices as n_vars × k arrays): 'PCs'
- layers (array transforms, same size as main data): 'counts'
- obsp (paired observation/cell annotations as n_obs x n_obs array): 'distances', 'connectivities'
- var_names (variable/gene vectors as pandas Index): names of genes

In [None]:
adata.uns['leiden_res_0.50']

In [None]:
import xarray as xr
import hvplot.xarray
import hvplot.pandas
import holoviews as hv
import pandas as pd
import numpy as np
from holoviews.operation.datashader import datashade, rasterize
hv.extension('bokeh')


### Convert adata to xarray

In [None]:
type(adata.X)

In [None]:
import sparse

In [None]:
sparse_X = sparse.COO.from_scipy_sparse(adata.X)

In [None]:
# # Convert the sparse matrix to a dense array
# dense_data = adata.X.toarray()

# Step 1: Convert primary data (X) to an xarray DataArray
data_array = xr.DataArray(
    sparse_X,
    dims=["cells", "genes"],
    coords={
        "cells": adata.obs_names,
        "genes": adata.var_names,
        "cell_type_lvl1": ("cells", adata.obs["cell_type_lvl1"]),
    }
)

In [None]:
# Step 2: Add PCA and UMAP embeddings to the xarray dataset
embeddings = xr.Dataset(
    {
        "X_pca": (("cells", "pca_components"), adata.obsm["X_pca"]),
        "X_umap": (("cells", "umap_components"), adata.obsm["X_umap"]),
    },
    coords={"cells": adata.obs_names}
)

# Combine `data_array` and `embeddings` into one xarray Dataset
xr_dataset = xr.Dataset({"data": data_array}, coords=data_array.coords)
xr_dataset = xr_dataset.merge(embeddings)

In [None]:
xr_dataset

### Alternative: convert `obs` and `var` to pandas DataFrames

In [None]:
adata.obsm['X_umap'].shape

In [None]:
umap_df = pd.DataFrame(adata.obsm['X_umap'], columns=['UMAP1', 'UMAP2'])
pca_df = pd.DataFrame(adata.obsm['X_pca'], columns=[f'PCA{1+i}' for i in range(adata.obsm['X_pca'].shape[-1])])

In [None]:
obs_df = adata.obs.join(umap_df.set_index(adata.obs.index))
obs_df =  obs_df.join(pca_df.set_index(adata.obs.index))
var_df = adata.var.copy()

In [None]:
# Extract expression data for marker genes
sel_genes = marker_genes['CD16+ Mono'] #["TCF7L2", "FCGR3A", "LYN"],
expression_df = pd.DataFrame(
    adata[:, sel_genes].X.toarray(), 
    columns=sel_genes, 
    index=adata.obs_names
)

In [None]:

# later try to datashade jittered scatter as overlay using .opts(jitter=)
cols2plot = ["n_genes_by_counts", "total_counts", "pct_counts_mt"]
violins = []
for col in cols2plot:
    violins.append(obs_df[col].hvplot.violin(width=300, ylabel='Value', title=col))
hv.Layout(violins).opts(shared_axes=False)

In [None]:
# Feature request: hide the hover tool over empty (no-data) pixels

In [None]:
obs_df.hvplot.scatter(x="total_counts", y="n_genes_by_counts", color="pct_counts_mt", cmap='Viridis', rasterize=True, title='pct_counts_mt', aspect=1)

In [None]:
# bug.. cannot get legend while using rasterize or datashade
# bug.. cannot set color_key while using 

In [None]:
pca1_2_count = obs_df.hvplot.scatter(x="PCA1", y="PCA2", color="pct_counts_mt", cmap='Viridis', title='pct_counts_mt', rasterize=True, cnorm='eq_hist', aspect=1)
pca1_2_samp = obs_df.hvplot.scatter(x="PCA1", y="PCA2", by="sample", legend=True, color=['blue', 'orange'], title='sample', datashade=True, cnorm='eq_hist', aspect=1)
# color_key={'s1d1':'blue', 's1d3':'orange'},
pca1_2_samp + pca1_2_count

In [None]:
umap1_2_count = obs_df.hvplot.scatter(x="UMAP1", y="UMAP2", color="pct_counts_mt", cmap='Viridis', title='pct_counts_mt', rasterize=True, cnorm='eq_hist', aspect=1)
umap1_2_samp = obs_df.hvplot.scatter(x="UMAP1", y="UMAP2", by="sample", legend=True, color=['blue', 'orange'], title='sample', datashade=True, cnorm='eq_hist', aspect=1)
umap1_2_samp + umap1_2_count

In [None]:
import datashader as ds

In [None]:
points = hv.Points(obs_df, kdims=['UMAP1', 'UMAP2'], vdims=['leiden_res_0.50'])
umap1_2_leiden_hv = rasterize(points, aggregator=ds.by('leiden_res_0.50')).opts(
        cmap='Category20', frame_height=300, aspect=1, tools=['hover'], title='leiden, HoloViews')

umap1_2_leiden_hvplot = obs_df.hvplot.scatter(x="UMAP1", y="UMAP2", by="leiden_res_0.50", cmap='Category20', title='leiden, hvPlot', datashade=True, cnorm='eq_hist', aspect=1)
umap1_2_leiden_hv + umap1_2_leiden_hvplot

### Generic dot plot

In [None]:
import pandas as pd
import holoviews as hv
from holoviews import opts
hv.extension('bokeh')

# Sample data
data = {
    'gene': ['GeneA', 'GeneA', 'GeneB', 'GeneB', 'GeneC', 'GeneC'],
    'cluster': ['Cluster1', 'Cluster2', 'Cluster1', 'Cluster2', 'Cluster1', 'Cluster2'],
    'percentage': [80, 60, 90, 50, 70, 40],  # Percentage of cells (0-100)
    'expression': [0.8, 0.3, 0.6, 0.2, 0.9, 0.4]  # Expression level (0-1)
}

df = pd.DataFrame(data)

# Map percentage to point size (Bokeh interprets `size` as a marker diameter in pixels).
# Choose max_point_size small enough that the largest dots don't overlap their neighbors.
max_point_size = 40
df['size'] = (df['percentage'] / 100) * max_point_size

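# Create the dot plot: hv.Points places both key dimensions on the axes
# (hv.Scatter is a 1D chart that plots the first key dimension against the first value dimension)
scatter = hv.Points(df, kdims=['gene', 'cluster'], vdims=['expression', 'size'])

# Customize plot
scatter = scatter.opts(
    opts.Points(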
        xrotation=45,  # Rotate x-axis labels if necessary
        color='expression',
        cmap='Viridis',  # Choose a colormap
        size='size',
        line_color='black',  # Outline color of the points
        marker='o',
        tools=['hover'],  # Add hover tooltips
        colorbar=True,
        width=600,
        height=400,
        xlabel='gene',
        ylabel='cluster',
        show_legend=False,
    )
)

# Display the plot
hv.output(scatter)
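The cell above maps percentage linearly to `size`, which Bokeh interprets as the marker diameter in screen pixels. As a result, a dot for 50% of cells covers only a quarter of the area of a 100% dot. If you want area rather than diameter to track the fraction of cells, one option is a square-root mapping. `percentage_to_marker_size` is a hypothetical helper that sketches both variants.

```python
import numpy as np

def percentage_to_marker_size(percentage, max_point_size=40.0, by_area=True):
    """Map a 0-100 percentage to a marker diameter in pixels (hypothetical helper).

    by_area=True makes marker *area* proportional to the percentage
    (diameter ~ sqrt(fraction)); by_area=False reproduces the linear
    diameter mapping used in the cell above.
    """
    fraction = np.clip(np.asarray(percentage, dtype=float) / 100.0, 0.0, 1.0)
    if by_area:
        return max_point_size * np.sqrt(fraction)
    return max_point_size * fraction

pct = np.array([100, 25, 0])
print(percentage_to_marker_size(pct))                 # area-proportional diameters
print(percentage_to_marker_size(pct, by_area=False))  # linear diameters
```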


In [None]:
sparse_X

In [None]:
obs_df.shape

In [None]:
obs_df["leiden_res_0.50"]

In [None]:
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

def compute_dotplot_data(expression, groupby, gene_names, markers, expression_cutoff):
    """
    Compute data required for creating a dot plot.

    Parameters:
    - expression: n_cells x n_genes matrix (could be sparse)
    - groupby: array of cluster assignments, length n_cells
    - gene_names: list of gene names, length n_genes
    - markers: list of gene names to include in the dot plot
    - expression_cutoff: value used to binarize expression (expression > cutoff)

    Returns:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression']
    """
    # Ensure expression is CSR format for efficient operations
    if not isinstance(expression, csr_matrix):
        expression = csr_matrix(expression)

    # Map gene names to indices
    gene_name_to_idx = {gene: idx for idx, gene in enumerate(gene_names)}
    
    # Get indices of marker genes
    marker_indices = [gene_name_to_idx[gene] for gene in markers if gene in gene_name_to_idx]

    # Subset the expression matrix to marker genes
    expression_subset = expression[:, marker_indices]
    marker_gene_names = [gene_names[idx] for idx in marker_indices]

    # Convert groupby to numpy array if it's not already
    groupby = np.array(groupby)
    clusters = np.unique(groupby)

    # Check if all clusters are numeric
    clusters_series = pd.Series(clusters)
    clusters_numeric = pd.to_numeric(clusters_series, errors='coerce')

    if clusters_numeric.isnull().any():
        # Not all clusters are numeric
        convert_cluster_to_numeric = False
    else:
        convert_cluster_to_numeric = True

    # Initialize list to collect results
    results = []

    # Binarize the expression data using the expression_cutoff
    expression_binarized = expression_subset.copy()
    expression_binarized.data = (expression_binarized.data > expression_cutoff).astype(int)
    
    # For each cluster
    for cluster in clusters:
        # Get indices of cells in the current cluster
        cluster_mask = groupby == cluster
        cluster_cell_indices = np.where(cluster_mask)[0]
        
        # Number of cells in the cluster
        n_cells_in_cluster = len(cluster_cell_indices)
        
        # Subset the expression data to cells in the cluster
        X_cluster = expression_subset[cluster_cell_indices, :]
        X_cluster_binarized = expression_binarized[cluster_cell_indices, :]
        
        # Compute the sum of binarized expression per gene (number of cells expressing the gene)
        expressing_cells = X_cluster_binarized.sum(axis=0).A1  # Sum over cells
        
        # Compute percentage of cells expressing each gene
        percentage = (expressing_cells / n_cells_in_cluster) * 100  # Percentage
        
        # Compute mean expression per gene over all cells in cluster
        total_expression = X_cluster.sum(axis=0).A1  # Sum of expression values per gene
        mean_expression = total_expression / n_cells_in_cluster  # Mean expression
        
        # Collect results into DataFrame
        cluster_results = pd.DataFrame({
            'gene': marker_gene_names,
            'cluster': cluster,
            'percentage': percentage,
            'mean_expression': mean_expression
        })
        results.append(cluster_results)

    # Concatenate results from all clusters
    df = pd.concat(results, ignore_index=True)

    # Convert 'cluster' column to numeric if all clusters are numeric
    if convert_cluster_to_numeric:
        df['cluster'] = pd.to_numeric(df['cluster'])

    return df
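The loop above slices the matrix once per cluster. The same two statistics can also be written as sparse matrix products with a one-hot cluster-membership matrix `G` (cells × clusters). `G.T @ X` gives the per-cluster expression sums, and `G.T @ B` gives the per-cluster counts of expressing cells, where `B` is the binarized matrix. The sketch below rests on those assumptions. The function name is hypothetical, it assumes every cell has a cluster label, and it returns wide clusters × genes tables rather than the long format used above.

```python
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

def dotplot_stats_vectorized(expression, groupby, expression_cutoff, gene_names=None):
    """Per-cluster % of expressing cells and mean expression via sparse products.

    expression: cells x genes (dense or sparse); groupby: one label per cell
    (no missing labels). Returns (percentage, mean) as clusters x genes DataFrames.
    """
    X = csr_matrix(expression)
    codes, clusters = pd.factorize(np.asarray(groupby), sort=True)
    n_cells = X.shape[0]
    # One-hot membership G: G[i, k] = 1 if cell i belongs to cluster k
    G = csr_matrix(
        (np.ones(n_cells), (np.arange(n_cells), codes)),
        shape=(n_cells, len(clusters)),
    )
    n_per_cluster = np.asarray(G.sum(axis=0)).ravel()

    # Binarize only the stored (nonzero) entries, as in the loop version
    B = X.copy()
    B.data = (B.data > expression_cutoff).astype(float)

    totals = (G.T @ X).toarray()      # sum of expression per cluster and gene
    expressing = (G.T @ B).toarray()  # number of expressing cells per cluster and gene
    pct = pd.DataFrame(expressing / n_per_cluster[:, None] * 100, index=clusters, columns=gene_names)
    mean = pd.DataFrame(totals / n_per_cluster[:, None], index=clusters, columns=gene_names)
    return pct, mean

X_toy = np.array([[0, 2], [1, 0], [3, 3], [0, 0]])
pct, mean = dotplot_stats_vectorized(X_toy, ["a", "a", "b", "b"], 0.5, gene_names=["g1", "g2"])
print(pct)
print(mean)
```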


In [None]:
dp_data = compute_dotplot_data(adata.X, adata.obs['leiden_res_0.50'], adata.var_names, ["ID2", "PLCG2", "GNLY", "SYNE1"], .1)

In [None]:
dp_data

In [None]:
import colorcet as cc

In [None]:
def plot_dotplot(df, max_point_size=40):
    """
    Create a dot plot from the DataFrame output by compute_dotplot_data.

    Parameters:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression']
    - max_point_size: dot size (in pixels) assigned to the largest percentage in df
    """
    # Work on a copy so the caller's DataFrame is not modified
    df = df.copy()

    # Map percentage to point sizes (the largest percentage gets max_point_size)
    df['size'] = (df['percentage'] / df['percentage'].max()) * max_point_size

    # Normalize mean_expression for color mapping
    df['mean_expression_normalized'] = df['mean_expression'] / df['mean_expression'].max()

    # Create the points plot
    points = hv.Points(df, kdims=['gene', 'cluster'], vdims=['mean_expression_normalized', 'size', 'percentage', 'mean_expression'])
    heatmap = hv.HeatMap(df, kdims=['gene', 'cluster'], vdims=['mean_expression_normalized'])

    # Extract sorted unique cluster values
    cluster_values = sorted(df['cluster'].unique())

    # Create yticks as a list of tuples (position, label)
    yticks = [(cluster, str(cluster)) for cluster in cluster_values]

    # Determine ylim to make Y-axis tight to data extents
    # Since we have invert_yaxis=True, the ylim should be from higher to lower values
    min_cluster = min(cluster_values)
    max_cluster = max(cluster_values)
    # Assuming clusters are integers and we want to center ticks, adjust by 0.5
    ylim = (max_cluster + 0.5, min_cluster - 0.5)

    # Customize the plot
    points = points.opts(
            xrotation=90,             # Rotate x-axis labels if necessary
            color='mean_expression_normalized',
            cmap='Reds',           # Choose a colormap
            size='size',
            line_color='white',       # Outline color of the points
            line_alpha=.1,
            marker='o',
            tools=['hover'],          # Add hover tooltips
            colorbar=True,
            min_height=400,
            responsive=True,
            xlabel='Gene',
            ylabel='Cluster',
            invert_yaxis=True,        # Optional: invert y-axis to match conventional heatmap orientation
            show_legend=False,
            yticks=yticks,
    )

    heatmap = heatmap.opts(cmap='Greys', fill_alpha=.2, ylim=ylim)

    return heatmap * points
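The `ylim`/`yticks` logic above assumes integer cluster labels, because it adds ±0.5 to the minimum and maximum cluster. Below is a sketch of a hypothetical helper, `cluster_axis`, that first maps any sortable labels (for example annotated cell-type names) to integer positions. That way the same padding also works for string labels. The returned `positions` would then serve as the y dimension in place of the raw labels.

```python
import pandas as pd

def cluster_axis(cluster_labels):
    """Map sortable cluster labels to integer positions (hypothetical helper).

    Returns (positions, yticks, ylim), with ylim running high to low to
    match the ylim convention used in plot_dotplot above.
    """
    order = sorted(pd.unique(pd.Series(cluster_labels)))
    position = {label: i for i, label in enumerate(order)}
    positions = pd.Series(cluster_labels).map(position)
    yticks = [(i, str(label)) for i, label in enumerate(order)]
    ylim = (len(order) - 0.5, -0.5)
    return positions, yticks, ylim

labels = ["B cell", "T cell", "B cell", "Mono"]
positions, yticks, ylim = cluster_axis(labels)
print(positions.tolist(), yticks, ylim)
```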

In [None]:
plot_dotplot(dp_data, max_point_size=15)

### Try to add gene group labels

In [None]:
import holoviews as hv
from holoviews import opts
hv.extension('bokeh')

def plot_dotplot(df, max_point_size=40):
    """
    Create a dot plot from the DataFrame output by compute_dotplot_data.
    If 'gene_group' is present in df, include an adjoined plot above the main plot to show gene groupings.

    Parameters:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression']
          If 'gene_group' is present, includes gene groupings.
    - max_point_size: dot size (in pixels) assigned to the largest percentage in df.
    """

    # Work on a copy so the caller's DataFrame is not modified
    df = df.copy()

    # Map percentage to point sizes
    df['size'] = (df['percentage'] / df['percentage'].max()) * max_point_size

    # Normalize mean_expression for color mapping
    df['mean_expression_normalized'] = df['mean_expression'] / df['mean_expression'].max()

    # Ensure 'gene' is a categorical variable with the correct order
    genes_order = df['gene'].drop_duplicates().tolist()
    df['gene'] = pd.Categorical(df['gene'], categories=genes_order, ordered=True)

    # Create the points plot
    points = hv.Points(
        df,
        kdims=['gene', 'cluster'],
        vdims=['mean_expression_normalized', 'size', 'percentage', 'mean_expression']
    )

    heatmap = hv.HeatMap(
        df,
        kdims=['gene', 'cluster'],
        vdims=['mean_expression_normalized']
    )

    # Extract sorted unique cluster values
    cluster_values = sorted(df['cluster'].unique())

    # Ensure cluster_values is not empty
    if not cluster_values:
        raise ValueError("No clusters found in the data.")

    # Create yticks as a list of tuples (position, label)
    yticks = [(cluster, str(cluster)) for cluster in cluster_values]

    # Determine ylim to make Y-axis tight to data extents
    min_cluster = min(cluster_values)
    max_cluster = max(cluster_values)
    ylim = (max_cluster + 0.5, min_cluster - 0.5)

    # Customize the points plot
    points = points.opts(
        xrotation=90,             # Rotate x-axis labels if necessary
        color='mean_expression_normalized',
        cmap='Reds',              # Choose a colormap
        size='size',
        line_color='white',       # Outline color of the points
        line_alpha=0.1,
        marker='o',
        tools=['hover'],          # Add hover tooltips
        colorbar=True,
        min_height=400,
        responsive=True,
        xlabel='Gene',
        ylabel='Cluster',
        invert_yaxis=True,        # Invert y-axis to match conventional heatmap orientation
        show_legend=False,
        yticks=yticks,
        ylim=ylim,
    )

    heatmap = heatmap.opts(
        cmap='Greys',
        fill_alpha=0.2,
        ylim=ylim,
        yticks=yticks,
    )

    # Check if 'gene_group' is in df
    if 'gene_group' in df.columns:
        # Create the adjoined plot for gene groups
        gene_groups = df[['gene', 'gene_group']].drop_duplicates()
        gene_groups = gene_groups.groupby('gene_group')['gene'].apply(list).reset_index()

        # Prepare the annotations (rectangles and labels)
        annotations = []
        x_positions = {gene: pos for pos, gene in enumerate(genes_order)}
        for _, row in gene_groups.iterrows():
            group = row['gene_group']
            genes_in_group = row['gene']
            start_gene = genes_in_group[0]
            end_gene = genes_in_group[-1]
            start_pos = x_positions[start_gene] - 0.5
            end_pos = x_positions[end_gene] + 0.5
            # Create a rectangle spanning from start_gene to end_gene
            rect = hv.Rectangles([(start_pos, 0, end_pos, 1)])
            text = hv.Text((start_pos + end_pos) / 2, 0.5, group).opts(text_align='center')
            annotations.append(rect.opts(fill_alpha=0, line_width=2, line_color='black') * text)

        # Combine annotations
        annotations = hv.Overlay(annotations).opts(
            xaxis=None,
            yaxis=None,
            show_frame=False,
            show_grid=False,
            height=50,
            ylim=(0, 1),
            xlim=(-0.5, len(genes_order) - 0.5),
            bgcolor='white',
        )

        # Layout the annotations above the main plot
        layout = (annotations + (heatmap * points)).cols(1)
        return layout

    else:
        # No gene groups, return the main plot
        return heatmap * points
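The rectangle code above takes the first and last gene of each group. This implicitly assumes that each group's genes sit next to each other in `genes_order`. If a group is split, its rectangle would also cover the genes in between. The sketch below uses `itertools.groupby` to return one span per contiguous run instead. `group_spans` is a hypothetical name, and positions follow the categorical convention that gene `i` is centered at `x = i`.

```python
from itertools import groupby

def group_spans(genes_order, gene_to_group):
    """Return (group, x_start, x_end) for each contiguous run of genes in one group.

    Gene i is assumed to be centered at x = i, so a run covering genes
    i..j spans [i - 0.5, j + 0.5]; a label could go at the midpoint.
    """
    spans = []
    start = 0
    # groupby yields consecutive genes sharing the same group as one run
    for group, run in groupby(genes_order, key=gene_to_group.get):
        n_genes = len(list(run))
        spans.append((group, start - 0.5, start + n_genes - 0.5))
        start += n_genes
    return spans

genes_order = ["CD14", "LYZ", "MS4A1", "FCGR3A"]  # toy order
gene_to_group = {"CD14": "Mono", "LYZ": "Mono", "MS4A1": "B", "FCGR3A": "Mono"}
spans = group_spans(genes_order, gene_to_group)
for group, x0, x1 in spans:
    print(group, x0, x1)
```

The split "Mono" group yields two spans, each of which would get its own rectangle.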


In [None]:
dp_data = compute_dotplot_data(adata.X, adata.obs['leiden_res_0.50'], adata.var_names, marker_genes, .1)
plot_dotplot(dp_data, max_point_size=15)

### Try to make the annotations a heatmap

In [None]:
import holoviews as hv
from holoviews import opts
hv.extension('bokeh')

def plot_dotplot(df, max_point_size=40):
    """
    Create a dot plot from the DataFrame output by compute_dotplot_data.
    If 'gene_group' is present in df, include annotations as an additional heatmap
    placed above the main dotplot, which has a row colored to encode the gene groups,
    and text labels for the gene groups.

    Parameters:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression']
          If 'gene_group' is present, includes gene groupings.
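    - max_point_size: dot size (in pixels) assigned to the largest percentage in df.

    Returns:
    - (layout, annotations_df) if 'gene_group' is in df, otherwise the overlay heatmap * points.
    """
    import pandas as pd

    # Work on a copy so the caller's DataFrame is not modified
    df = df.copy()

    # Map percentage to point sizes
    df['size'] = (df['percentage'] / df['percentage'].max()) * max_point_size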

    # Normalize mean_expression for color mapping
    df['mean_expression_normalized'] = df['mean_expression'] / df['mean_expression'].max()

    # Ensure 'gene' is a categorical variable with the correct order
    genes_order = df['gene'].drop_duplicates().tolist()
    df['gene'] = pd.Categorical(df['gene'], categories=genes_order, ordered=True)

    # Ensure 'cluster' is a categorical variable with a consistent order
    clusters_order = df['cluster'].drop_duplicates().tolist()
    df['cluster'] = pd.Categorical(df['cluster'], categories=clusters_order, ordered=True)

    # Create the main heatmap and points plot
    points = hv.Points(
        df,
        kdims=['gene', 'cluster'],
        vdims=['mean_expression_normalized', 'size', 'percentage', 'mean_expression']
    )

    heatmap = hv.HeatMap(
        df,
        kdims=['gene', 'cluster'],
        vdims=['mean_expression_normalized']
    )

    # Create yticks as a list of tuples (position, label)
    yticks = list(enumerate(clusters_order))

    # Customize the points plot
    points = points.opts(
        xrotation=90,
        color='mean_expression_normalized',
        cmap='Reds',
        size='size',
        line_color='white',
        line_alpha=0.1,
        marker='o',
        tools=['hover'],
        colorbar=True,
        min_height=400,
        responsive=True,
        xlabel='Gene',
        ylabel='Cluster',
        invert_yaxis=True,
        show_legend=False,
        yticks=yticks,
    )

    # Customize the heatmap
    heatmap = heatmap.opts(
        cmap='Greys',
        fill_alpha=0.2,
        invert_yaxis=False,
        yticks=yticks,
    )

    if 'gene_group' in df.columns:
        # Create a DataFrame for the annotations heatmap
        gene_groups = df[['gene', 'gene_group']].drop_duplicates()
        # Map gene groups to integer codes
        gene_groups['group_code'] = gene_groups['gene_group'].factorize()[0]

        # Create a DataFrame for the annotations heatmap
        annotations_df = gene_groups[['gene', 'group_code']].copy()
        annotations_df['Group'] = 'Group'  # Dummy variable for y-axis

        # Ensure 'gene' and 'Group' are categorical with the correct order
        annotations_df['gene'] = pd.Categorical(annotations_df['gene'], categories=genes_order, ordered=True)
        annotations_df['Group'] = pd.Categorical(annotations_df['Group'], categories=['Group'], ordered=True)

        # Create the annotations heatmap
        annotations_heatmap = hv.HeatMap(
            annotations_df,
            kdims=['gene', 'Group'],
            vdims=['group_code']
        )

        # Customize the annotations heatmap
        annotations_heatmap = annotations_heatmap.opts(
            colorbar=False,
            xaxis=None,
            yaxis=None,
            invert_yaxis=True,
            responsive=True,
            cmap=['black', 'grey'] * 10,
            tools=['hover'],
            toolbar=None,
        )

        # Create text labels for the gene groups
        # Calculate the center position of each gene group
        gene_group_positions = gene_groups.groupby('gene_group')['gene'].apply(list)
        text_annotations = []
        for group, genes_in_group in gene_group_positions.items():
            # Calculate the center position of the group
            x_positions = [genes_order.index(gene) for gene in genes_in_group]
            x_center = (min(x_positions) + max(x_positions)) / 2
            text = hv.Text(genes_order[int(x_center)], 'Group', group).opts(
                text_align='right',
                text_baseline='middle',
                text_font_size='8pt',
                text_color='black',
                xaxis=None,
                yaxis=None,
                angle=90,
            )
            text_annotations.append(text)

        # Combine the text annotations
        text_overlay = hv.Overlay(text_annotations)

        # Overlay the text on the annotations heatmap
        annotations_plot = (annotations_heatmap * text_overlay).opts(
            xaxis=None,
            yaxis=None,
            responsive=True,
            height=50,
            toolbar=None,
            show_frame=False,
            show_grid=False,
        )

        # Stack the annotations_plot above the main plot
        layout = hv.Layout([annotations_plot, heatmap * points]).cols(1).opts(
            opts.Layout(shared_axes=True)
        )
        return layout, annotations_df

    else:
        # No gene groups, proceed without annotations
        return heatmap * points


dp_data = compute_dotplot_data(adata.X, adata.obs['leiden_res_0.50'], adata.var_names, marker_genes, .1)
layout, annotations_df = plot_dotplot(dp_data, max_point_size=15)
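`cmap=['black', 'grey']*10` is meant to alternate the band color between neighboring groups. However, a list colormap applied to a numeric value is likely stretched over the whole `group_code` range, so neighboring groups are not guaranteed to get different colors. A more direct encoding is to color by the parity of the group code. The sketch below assumes a two-color `cmap=['black', 'grey']` on the annotation `HeatMap` with `vdims=['band']`. The genes and groups are toy values.

```python
import pandas as pd

genes_order = ["CD14", "LYZ", "MS4A1", "CD79A", "GNLY", "NKG7"]  # toy order
gene_group = ["Mono", "Mono", "B", "B", "NK", "NK"]

annotations = pd.DataFrame({"gene": genes_order, "gene_group": gene_group})
# factorize numbers groups in order of first appearance: Mono=0, B=1, NK=2
annotations["group_code"] = annotations["gene_group"].factorize()[0]
# Parity gives a strict 0/1 alternation between neighboring groups
annotations["band"] = annotations["group_code"] % 2
print(annotations)
```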

In [None]:
annotations_df

In [None]:
layout

## Try to fix the gene grouping display

In [None]:
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
import holoviews as hv
from holoviews import opts
hv.extension('bokeh')

def compute_dotplot_data(expression, groupby, gene_names, markers, expression_cutoff):
    """
    Compute data required for creating a dot plot, handling markers as a list or a dictionary,
    and allowing for genes to appear in multiple groups.

    Parameters:
    - expression: n_cells x n_genes matrix (could be sparse)
    - groupby: array of cluster assignments, length n_cells
    - gene_names: list of gene names, length n_genes
    - markers: either a list of gene names or a dictionary with keys as group labels and values as lists of gene names
    - expression_cutoff: value used for binarizing expression (expression > cutoff)

    Returns:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression', 'gene_group']
    """
    # Ensure expression is CSR format for efficient operations
    if not isinstance(expression, csr_matrix):
        expression = csr_matrix(expression)

    # Map gene names to indices
    gene_name_to_idx = {gene: idx for idx, gene in enumerate(gene_names)}

    # Check if markers is a dictionary or a list
    if isinstance(markers, dict):
        # Markers is a dictionary with group labels
        # Flatten markers dictionary to get list of (gene, group)
        marker_genes = []
        for group, genes in markers.items():
            for gene in genes:
                if gene in gene_name_to_idx:
                    marker_genes.append((gene, group))

        # Do not remove duplicates, keep genes as they appear
        # Get the list of gene indices, allowing for duplicates
        marker_indices = [gene_name_to_idx[gene] for gene, group in marker_genes]
        marker_gene_names = [gene for gene, group in marker_genes]
        gene_groups = [group for gene, group in marker_genes]

    elif isinstance(markers, list):
        # Markers is a list of gene names
        marker_genes = [gene for gene in markers if gene in gene_name_to_idx]

        # Remove duplicates while preserving order
        marker_genes = list(dict.fromkeys(marker_genes))

        # Get indices of marker genes
        marker_indices = [gene_name_to_idx[gene] for gene in marker_genes]
        marker_gene_names = marker_genes
        gene_groups = [None] * len(marker_gene_names)

    else:
        raise ValueError("Markers must be a list or a dictionary.")

    # Convert groupby to numpy array if it's not already
    groupby = np.array(groupby)
    clusters = np.unique(groupby)

    # Check if all clusters are numeric
    clusters_series = pd.Series(clusters)
    clusters_numeric = pd.to_numeric(clusters_series, errors='coerce')

    if clusters_numeric.isnull().any():
        convert_cluster_to_numeric = False
    else:
        convert_cluster_to_numeric = True

    # Initialize list to collect results
    results = []

    # For each marker gene (could be duplicated)
    for gene_idx, gene_name, gene_group in zip(marker_indices, marker_gene_names, gene_groups):
        # Get the expression vector for this gene
        gene_expression = expression[:, gene_idx]

        # Binarize the expression data using the expression_cutoff
        gene_expression_binarized = gene_expression.copy()
        gene_expression_binarized.data = (gene_expression_binarized.data > expression_cutoff).astype(int)

        # For each cluster
        for cluster in clusters:
            # Get indices of cells in the current cluster
            cluster_mask = groupby == cluster
            cluster_cell_indices = np.where(cluster_mask)[0]

            # Number of cells in the cluster
            n_cells_in_cluster = len(cluster_cell_indices)

            # Subset the expression data to cells in the cluster
            X_cluster = gene_expression[cluster_cell_indices]
            X_cluster_binarized = gene_expression_binarized[cluster_cell_indices]

            # Compute the sum of binarized expression (number of cells expressing the gene)
            expressing_cells = X_cluster_binarized.sum()

            # Compute percentage of cells expressing the gene
            percentage = (expressing_cells / n_cells_in_cluster) * 100  # Percentage

            # Compute mean expression per gene over all cells in cluster
            total_expression = X_cluster.sum()
            mean_expression = total_expression / n_cells_in_cluster  # Mean expression

            # Collect results into DataFrames
            cluster_results = pd.DataFrame({
                'gene': [gene_name],
                'cluster': [cluster],
                'percentage': [percentage],
                'mean_expression': [mean_expression],
                'gene_group': [gene_group]
            })
            results.append(cluster_results)

    # Concatenate results from all clusters and all genes
    df = pd.concat(results, ignore_index=True)

    # Convert 'cluster' column to numeric if all clusters are numeric
    if convert_cluster_to_numeric:
        df['cluster'] = pd.to_numeric(df['cluster'])

    return df


def plot_dotplot(df, max_point_size=40):
    """
    Create a dot plot from the DataFrame output by compute_dotplot_data,
    handling genes that appear multiple times due to being in multiple groups.

    Parameters:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression', 'gene_group']
    - max_point_size: dot size (in pixels) assigned to the largest percentage in df.

    Returns:
    - plot: Holoviews object representing the dot plot.
    """
    import pandas as pd
    import numpy as np

    # Create a unique identifier for each gene entry
    df = df.copy()
    df['gene_id'] = df.apply(
        lambda row: f"{row['gene']} ({row['gene_group']})" if pd.notnull(row['gene_group']) else row['gene'],
        axis=1
    )

    # Map percentage to point sizes
    df['size'] = (df['percentage'] / df['percentage'].max()) * max_point_size

    # Normalize mean_expression for color mapping
    df['mean_expression_normalized'] = df['mean_expression'] / df['mean_expression'].max()

    # Ensure 'gene_id' is a categorical variable with the correct order
    gene_ids_order = df['gene_id'].drop_duplicates().tolist()
    df['gene_id'] = pd.Categorical(df['gene_id'], categories=gene_ids_order, ordered=True)

    # # Ensure 'cluster' is a categorical variable with a consistent order
    # clusters_order = df['cluster'].drop_duplicates().tolist()
    # df['cluster'] = pd.Categorical(df['cluster'], categories=clusters_order, ordered=True)

    # Extract sorted unique cluster values
    cluster_values = sorted(df['cluster'].unique())

    # Determine ylim to make the Y-axis tight to the data extents.
    # The limits run from high to low, intended to draw the lowest cluster at the top
    # (invert_yaxis is left False in the opts below).
    min_cluster = min(cluster_values)
    max_cluster = max(cluster_values)
    # Assumes numeric (integer) clusters; pad by 0.5 so the outermost rows are not clipped
    ylim = (max_cluster + 0.5, min_cluster - 0.5)

    # Create yticks as a list of tuples (position, label)
    yticks = [(cluster, str(cluster)) for cluster in cluster_values]

    # Create the main heatmap and points plot
    points = hv.Points(
        df,
        kdims=['gene_id', 'cluster'],
        vdims=['mean_expression_normalized', 'size', 'percentage', 'mean_expression', 'gene_group']
    )

    # heatmap = hv.HeatMap(
    #     df,
    #     kdims=['gene_id', 'cluster'],
    #     vdims=['mean_expression_normalized']
    # )

    # Create yticks as a list of tuples (index, label)
    # yticks = list(enumerate(clusters_order))

    # Customize the points plot
    points = points.opts(
        xrotation=90,
        color='mean_expression_normalized',
        cmap='Reds',
        size='size',
        line_color='black',
        line_alpha=0.1,
        marker='o',
        tools=['hover'],
        colorbar=True,
        min_height=400,
        responsive=True,
        xlabel='Gene',
        ylabel='Cluster',
        invert_yaxis=False,
        show_legend=False,
        yticks=yticks,
        ylim=ylim,
    )

    # # Customize the heatmap
    # heatmap = heatmap.opts(
    #     cmap='Greys',
    #     fill_alpha=0.2,
    #     invert_yaxis=True,
    #     yticks=yticks,
    #     ylim=ylim,
    # )

    if 'gene_group' in df.columns:
        # Create a DataFrame for the annotations heatmap
        gene_groups = df[['gene_id', 'gene_group']].drop_duplicates()
        # Map gene groups to integer codes
        gene_groups['group_code'] = gene_groups['gene_group'].factorize()[0]

        # Create a DataFrame for the annotations heatmap
        annotations_df = gene_groups[['gene_id', 'group_code', 'gene_group']].copy()
        annotations_df['Group'] = 'Group'  # Dummy variable for y-axis

        # Ensure 'gene_id' and 'Group' are categorical with the correct order
        annotations_df['gene_id'] = pd.Categorical(annotations_df['gene_id'], categories=gene_ids_order, ordered=True)
        annotations_df['Group'] = pd.Categorical(annotations_df['Group'], categories=['Group'], ordered=True)

        # Create the annotations heatmap
        annotations_heatmap = hv.HeatMap(
            annotations_df,
            kdims=['gene_id', 'Group'],
            vdims=['group_code', 'gene_group']
        )

        # Customize the annotations heatmap
        annotations_heatmap = annotations_heatmap.opts(
            colorbar=False,
            xaxis=None,
            yaxis=None,
            responsive=True,
            cmap='glasbey_hv',
            tools=['hover'],
            toolbar=None,
            height=50,
            show_frame=False,
            show_grid=False,
        )

        # # Create text labels for the gene groups
        # # Calculate the center position of each gene group
        # gene_group_positions = gene_groups.groupby('gene_group')['gene_id'].apply(list)
        # text_annotations = []
        # for group, gene_ids_in_group in gene_group_positions.items():
        #     # Calculate the center position of the group
        #     x_positions = [gene_ids_order.index(gene_id) for gene_id in gene_ids_in_group]
        #     x_center = (min(x_positions) + max(x_positions)) / 2
        #     text = hv.Text(gene_ids_order[int(x_center)], 'Group', group).opts(
        #         text_align='right',
        #         text_baseline='middle',
        #         text_font_size='8pt',
        #         text_color='black',
        #         xaxis=None,
        #         yaxis=None,
        #         angle=90,
        #     )
        #     text_annotations.append(text)

        # # Combine the text annotations
        # text_overlay = hv.Overlay(text_annotations)

        # # Overlay the text on the annotations heatmap
        # annotations_plot = (annotations_heatmap * text_overlay).opts(
        #     xaxis=None,
        #     yaxis=None,
        #     responsive=True,
        #     height=50,
        #     toolbar=None,
        #     show_frame=False,
        #     show_grid=False,
        # )

        # Stack the annotations_plot above the main plot
        layout = hv.Layout([annotations_heatmap, points]).cols(1).opts(
            opts.Layout(shared_axes=True)
        )
        return layout

    else:
        # No gene groups: return the points alone (the heatmap layer is disabled above)
        return points

dp_data = compute_dotplot_data(adata.X, adata.obs['leiden_res_0.50'], adata.var_names, marker_genes, .1)
plot_dotplot(dp_data, max_point_size=12)
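The dictionary branch of `compute_dotplot_data` flattens the markers into (gene, group) pairs and silently drops genes not found in `gene_names`. `plot_dotplot` then builds a `'GENE (group)'` label row by row, so duplicated genes get separate x positions. The sketch below combines both steps in pandas. `flatten_markers` is a hypothetical helper that also returns the dropped genes so they can be reported, and the marker dictionary is a toy example.

```python
import pandas as pd

def flatten_markers(markers, gene_names):
    """Flatten {group: [genes]} into a long (gene, gene_group, gene_id) table.

    Duplicates across groups are kept; genes missing from gene_names are
    dropped (as in compute_dotplot_data) but returned so they can be reported.
    """
    known = set(gene_names)
    rows = [(gene, group) for group, genes in markers.items() for gene in genes]
    table = pd.DataFrame(rows, columns=["gene", "gene_group"])
    is_known = table["gene"].isin(known)
    missing = table.loc[~is_known, "gene"].unique().tolist()
    table = table[is_known].reset_index(drop=True)
    # Vectorized counterpart of the row-wise df.apply(...) in plot_dotplot
    table["gene_id"] = table["gene"] + " (" + table["gene_group"].astype(str) + ")"
    return table, missing

markers = {"B": ["MS4A1", "CD79A"], "pDC": ["CD79A", "NOT_A_GENE"]}  # toy markers
table, missing = flatten_markers(markers, ["MS4A1", "CD79A", "GNLY"])
print(table)
print("dropped:", missing)
```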

### Without the commented-out code

In [None]:
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
import holoviews as hv
from holoviews import opts
hv.extension('bokeh')

def compute_dotplot_data(expression, groupby, gene_names, markers, expression_cutoff):
    """
    Compute data required for creating a dot plot, handling markers as a list or a dictionary,
    and allowing for genes to appear in multiple groups.

    Parameters:
    - expression: n_cells x n_genes matrix (could be sparse)
    - groupby: array of cluster assignments, length n_cells
    - gene_names: list of gene names, length n_genes
    - markers: either a list of gene names or a dictionary with keys as group labels and values as lists of gene names
    - expression_cutoff: value used for binarizing expression (expression > cutoff)

    Returns:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression', 'gene_group']
    """
    # Ensure expression is CSR format for efficient operations
    if not isinstance(expression, csr_matrix):
        expression = csr_matrix(expression)

    # Map gene names to indices
    gene_name_to_idx = {gene: idx for idx, gene in enumerate(gene_names)}

    # Check if markers is a dictionary or a list
    if isinstance(markers, dict):
        # Markers is a dictionary with group labels
        # Flatten markers dictionary to get list of (gene, group)
        marker_genes = []
        for group, genes in markers.items():
            for gene in genes:
                if gene in gene_name_to_idx:
                    marker_genes.append((gene, group))

        # Do not remove duplicates, keep genes as they appear
        # Get the list of gene indices, allowing for duplicates
        marker_indices = [gene_name_to_idx[gene] for gene, group in marker_genes]
        marker_gene_names = [gene for gene, group in marker_genes]
        gene_groups = [group for gene, group in marker_genes]

    elif isinstance(markers, list):
        # Markers is a list of gene names
        marker_genes = [gene for gene in markers if gene in gene_name_to_idx]

        # Remove duplicates while preserving order
        marker_genes = list(dict.fromkeys(marker_genes))

        # Get indices of marker genes
        marker_indices = [gene_name_to_idx[gene] for gene in marker_genes]
        marker_gene_names = marker_genes
        gene_groups = [None] * len(marker_gene_names)

    else:
        raise ValueError("Markers must be a list or a dictionary.")

    # Convert groupby to numpy array if it's not already
    groupby = np.array(groupby)
    clusters = np.unique(groupby)

    # Check if all clusters are numeric
    clusters_series = pd.Series(clusters)
    clusters_numeric = pd.to_numeric(clusters_series, errors='coerce')

    if clusters_numeric.isnull().any():
        convert_cluster_to_numeric = False
    else:
        convert_cluster_to_numeric = True

    # Precompute the cell indices of each cluster once, not once per gene
    cluster_cell_indices = {cluster: np.flatnonzero(groupby == cluster) for cluster in clusters}

    # Collect one record (dict) per (gene, cluster) pair
    results = []

    # For each marker gene (could be duplicated)
    for gene_idx, gene_name, gene_group in zip(marker_indices, marker_gene_names, gene_groups):
        # Expression vector for this gene (n_cells x 1 sparse column)
        gene_expression = expression[:, gene_idx]

        # Binarize: 1 where expression > expression_cutoff (implicit zeros stay 0)
        gene_expression_binarized = gene_expression.copy()
        gene_expression_binarized.data = (gene_expression_binarized.data > expression_cutoff).astype(int)

        for cluster, cell_indices in cluster_cell_indices.items():
            n_cells_in_cluster = len(cell_indices)

            # Number of cells in the cluster expressing the gene above the cutoff
            expressing_cells = gene_expression_binarized[cell_indices].sum()

            results.append({
                'gene': gene_name,
                'cluster': cluster,
                # Percentage of cells in the cluster expressing the gene
                'percentage': expressing_cells / n_cells_in_cluster * 100,
                # Mean expression over all cells in the cluster (zeros included)
                'mean_expression': gene_expression[cell_indices].sum() / n_cells_in_cluster,
                'gene_group': gene_group,
            })

    # One row per (gene, cluster) pair
    df = pd.DataFrame(results)

    # Convert 'cluster' column to numeric if all clusters are numeric
    if convert_cluster_to_numeric:
        df['cluster'] = pd.to_numeric(df['cluster'])

    return df

def plot_dotplot(df, max_point_size=40):
    """
    Create a dot plot from the DataFrame output by compute_dotplot_data,
    handling genes that appear multiple times due to being in multiple groups.

    Parameters:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression', 'gene_group']
    - max_point_size: Maximum dot size, used for the largest percentage present in df
      (sizes are scaled relative to that maximum, not to 100%).

    Returns:
    - plot: HoloViews object representing the dot plot.
    """
    import pandas as pd
    import numpy as np
    import holoviews as hv
    from holoviews import opts

    # Create a unique identifier for each gene entry
    df = df.copy()
    df['gene_id'] = df.apply(
        lambda row: f"{row['gene']} ({row['gene_group']})" if pd.notnull(row['gene_group']) else row['gene'],
        axis=1
    )

    # Map percentage to point sizes
    df['size'] = (df['percentage'] / df['percentage'].max()) * max_point_size

    # Normalize mean_expression for color mapping
    df['mean_expression_normalized'] = df['mean_expression'] / df['mean_expression'].max()

    # Ensure 'gene_id' is a categorical variable with the correct order
    gene_ids_order = df['gene_id'].drop_duplicates().tolist()
    df['gene_id'] = pd.Categorical(df['gene_id'], categories=gene_ids_order, ordered=True)

    # Extract sorted unique cluster values
    cluster_values = sorted(df['cluster'].unique())

    # Determine ylim to make Y-axis tight to data extents
    min_cluster = min(cluster_values)
    max_cluster = max(cluster_values)
    ylim = (max_cluster + 0.5, min_cluster - 0.5)

    # Create yticks as a list of tuples (position, label)
    yticks = [(cluster, str(cluster)) for cluster in cluster_values]

    # Create the main points plot
    points = hv.Points(
        df,
        kdims=['gene_id', 'cluster'],
        vdims=['mean_expression_normalized', 'size', 'percentage', 'mean_expression', 'gene_group']
    )

    # Customize the points plot
    points = points.opts(
        xrotation=90,
        color='mean_expression_normalized',
        cmap='Reds',
        size='size',
        line_color='black',
        line_alpha=0.1,
        marker='o',
        tools=['hover'],
        colorbar=True,
        min_height=400,
        responsive=True,
        xlabel='Gene',
        ylabel='Cluster',
        invert_yaxis=False,
        show_legend=False,
        yticks=yticks,
        ylim=ylim,
    )

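    # compute_dotplot_data always returns a 'gene_group' column, so check whether
    # it holds any groups (it is all None when markers was a plain list)
    if df['gene_group'].notna().any():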
        # Create a DataFrame for the annotations heatmap
        gene_groups = df[['gene_id', 'gene_group']].drop_duplicates()
        # Map gene groups to integer codes
        gene_groups['group_code'] = gene_groups['gene_group'].factorize()[0]

        # Create a DataFrame for the annotations heatmap
        annotations_df = gene_groups[['gene_id', 'group_code', 'gene_group']].copy()
        annotations_df['Group'] = 'Group'  # Dummy variable for y-axis

        # Ensure 'gene_id' and 'Group' are categorical with the correct order
        annotations_df['gene_id'] = pd.Categorical(annotations_df['gene_id'], categories=gene_ids_order, ordered=True)
        annotations_df['Group'] = pd.Categorical(annotations_df['Group'], categories=['Group'], ordered=True)

        # Create the annotations heatmap
        annotations_heatmap = hv.HeatMap(
            annotations_df,
            kdims=['gene_id', 'Group'],
            vdims=['group_code', 'gene_group']
        )

        # Customize the annotations heatmap
        annotations_heatmap = annotations_heatmap.opts(
            colorbar=False,
            xaxis=None,
            yaxis=None,
            responsive=True,
            cmap='glasbey_hv',
            tools=['hover'],
            toolbar=None,
            height=50,
            show_frame=False,
            show_grid=False,
        )

        # Stack the annotations_heatmap above the main plot
        layout = hv.Layout([annotations_heatmap, points]).cols(1).opts(
            opts.Layout(shared_axes=True)
        )
        return layout

    else:
        # No gene groups, proceed without annotations
        return points
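For comparison, the per-gene, per-cluster loop in `compute_dotplot_data` can be collapsed into two sparse matrix products. The sketch below is an alternative, not the function used in this notebook: `cluster_summary` is a hypothetical helper that builds a one-hot cluster-membership matrix and multiplies it with the binarized and the raw expression matrices, returning the percentage of expressing cells and the mean expression for every cluster and gene at once. Like the loop, it binarizes only stored (non-zero) entries, so it assumes a non-negative `expression_cutoff`.

```python
import numpy as np
import scipy.sparse as sp


def cluster_summary(expression, groupby, expression_cutoff=0.0):
    """Hypothetical vectorized helper: per-cluster percentage of expressing
    cells and mean expression for every gene (clusters x genes arrays)."""
    X = sp.csr_matrix(expression)
    groupby = np.asarray(groupby)
    clusters, codes = np.unique(groupby, return_inverse=True)

    # One-hot membership matrix: rows = clusters, columns = cells
    n_cells = X.shape[0]
    membership = sp.csr_matrix(
        (np.ones(n_cells), (codes, np.arange(n_cells))),
        shape=(len(clusters), n_cells),
    )
    cluster_sizes = np.asarray(membership.sum(axis=1)).ravel()

    # Binarize stored entries only; implicit zeros remain zero (assumes cutoff >= 0)
    X_bin = X.copy()
    X_bin.data = (X_bin.data > expression_cutoff).astype(float)

    # Sum over the cells of each cluster, then divide by the cluster size
    percentage = np.asarray((membership @ X_bin).todense()) / cluster_sizes[:, None] * 100
    mean_expression = np.asarray((membership @ X).todense()) / cluster_sizes[:, None]
    return clusters, percentage, mean_expression


# Toy data: 4 cells x 2 genes, two clusters
X = np.array([[0.0, 2.0],
              [1.0, 0.0],
              [3.0, 0.0],
              [0.0, 0.0]])
clusters, pct, mean = cluster_summary(X, ["a", "a", "b", "b"], expression_cutoff=0.5)
print(clusters)
print(pct)
print(mean)
```

On this toy matrix, cluster `a` has 50% expressing cells for both genes (mean expression 0.5 and 1.0), while cluster `b` has 50% for the first gene (mean 1.5) and 0% for the second.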


In [None]:
dp_data = compute_dotplot_data(adata.X, adata.obs['leiden_res_0.50'], adata.var_names, marker_genes, .1)
plot_dotplot(dp_data, max_point_size=12)
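When `marker_genes` is a dictionary, a gene listed under several cell types yields one row per (gene, group) pair, and `plot_dotplot` labels each column as `gene (group)` so the x-axis categories stay unique. The standalone sketch below reproduces only that flattening and labelling step; `flatten_markers` and `gene_labels` are hypothetical helper names, and the marker dictionary is made up for illustration.

```python
def flatten_markers(markers, available_genes):
    """Return (gene, group) pairs in dictionary order, keeping genes that
    appear in several groups and skipping genes missing from the data."""
    available = set(available_genes)
    return [(gene, group)
            for group, genes in markers.items()
            for gene in genes
            if gene in available]


def gene_labels(pairs):
    """Build the 'gene (group)' labels used as x-axis categories."""
    return [f"{gene} ({group})" for gene, group in pairs]


markers = {
    "B cells": ["CD79A", "MS4A1"],
    "Plasma cells": ["CD79A", "MZB1"],
    "Hypothetical": ["NOT_A_GENE"],  # dropped: not among the measured genes
}
pairs = flatten_markers(markers, ["CD79A", "MS4A1", "MZB1"])
labels = gene_labels(pairs)
print(labels)
```

`CD79A` appears twice, once per group, but each occurrence gets its own label, which is why the dot plot can show the same gene in two annotated blocks.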

### With dendrogram
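The variant below orders clusters by hierarchical clustering of their mean marker expression: it pivots the dot-plot data into a cluster × gene matrix, computes Euclidean distances with `scipy.spatial.distance.pdist`, links them with average linkage, and reads the leaf order from `scipy.cluster.hierarchy.dendrogram(..., no_plot=True)`. The minimal sketch here isolates that ordering step on a small made-up matrix; the cluster labels and values are illustrative only.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist

# Mean expression of 3 marker genes (columns) in 4 clusters (rows); toy values
cluster_labels = ["0", "1", "2", "3"]
mean_expr = np.array([
    [5.0, 0.0, 0.0],
    [0.0, 4.0, 1.0],
    [4.8, 0.2, 0.0],
    [0.0, 4.2, 0.9],
])

# Average-linkage clustering on pairwise Euclidean distances between clusters
Z = linkage(pdist(mean_expr, metric="euclidean"), method="average")

# no_plot=True returns the layout dictionary without drawing anything
dendro = dendrogram(Z, labels=cluster_labels, no_plot=True)

# 'leaves' holds original row indices in display order; 'ivl' the matching labels
order = dendro["ivl"]
assert order == [cluster_labels[i] for i in dendro["leaves"]]
print(order)
```

Clusters with similar profiles (`0`/`2` and `1`/`3` here) end up next to each other, which is what makes the reordered dot plot easier to read.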

In [None]:
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist
import holoviews as hv
from holoviews import opts
import panel as pn
pn.extension('gridstack')
hv.extension('bokeh')

def plot_dotplot(df, max_point_size=40):
    """
    Create a dot plot from the DataFrame output by compute_dotplot_data,
    handling genes that appear multiple times due to being in multiple groups,
    and adding a dendrogram to the right of the main plot.

    Parameters:
    - df: DataFrame with columns ['gene', 'cluster', 'percentage', 'mean_expression', 'gene_group']
    - max_point_size: Maximum dot size, used for the largest percentage present in df
      (sizes are scaled relative to that maximum, not to 100%).

    Returns:
    - plot: HoloViews object representing the dot plot with a dendrogram.
    """
    import pandas as pd
    import numpy as np
    from scipy.cluster.hierarchy import linkage, dendrogram
    from scipy.spatial.distance import pdist
    import holoviews as hv

    # Create a unique identifier for each gene entry
    df = df.copy()
    df['gene_id'] = df.apply(
        lambda row: f"{row['gene']} ({row['gene_group']})" if pd.notnull(row['gene_group']) else row['gene'],
        axis=1
    )

    # Map percentage to point sizes
    df['size'] = (df['percentage'] / df['percentage'].max()) * max_point_size

    # Normalize mean_expression for color mapping
    df['mean_expression_normalized'] = df['mean_expression'] / df['mean_expression'].max()

    # Ensure 'gene_id' is a categorical variable with the correct order
    gene_ids_order = df['gene_id'].drop_duplicates().tolist()
    df['gene_id'] = pd.Categorical(df['gene_id'], categories=gene_ids_order, ordered=True)

    # Create the cluster-gene matrix for clustering
    cluster_gene_matrix = df.pivot_table(index='cluster', columns='gene_id', values='mean_expression', fill_value=0)

    # Perform hierarchical clustering on clusters
    # Compute the distance matrix
    X = cluster_gene_matrix.values
    cluster_dist = pdist(X, metric='euclidean')

    # Perform hierarchical clustering
    cluster_linkage = linkage(cluster_dist, method='average')

    # Compute the dendrogram layout without drawing it; pass the labels as a plain list
    dendro_data = dendrogram(cluster_linkage, labels=cluster_gene_matrix.index.tolist(), no_plot=True)

    # Get the order of clusters from dendrogram
    clusters_ordered = dendro_data['ivl']

    # Map each cluster label to its row position in dendrogram order
    cluster_positions = {cluster: pos for pos, cluster in enumerate(clusters_ordered)}
    df['cluster_pos'] = df['cluster'].map(cluster_positions).astype(int)

    # Keep 'cluster' as a categorical in dendrogram order
    df['cluster'] = pd.Categorical(df['cluster'], categories=clusters_ordered, ordered=True)

    # Y-axis ticks sit at the row positions and carry the cluster names, so the
    # labels follow the dendrogram order rather than the sorted cluster ids
    yticks = [(pos, str(cluster)) for pos, cluster in enumerate(clusters_ordered)]

    # Tight y-limits around the row positions
    ylim = (len(clusters_ordered) - 0.5, -0.5)

    # Create the main points plot
    points = hv.Points(
        df,
        kdims=['gene_id', 'cluster_pos'],
        vdims=['mean_expression_normalized', 'size', 'percentage', 'mean_expression', 'gene_group']
    )

    # Customize the points plot
    points = points.opts(
        xrotation=90,
        color='mean_expression_normalized',
        cmap='Reds',
        size='size',
        line_color='black',
        line_alpha=0.1,
        marker='o',
        tools=['hover'],
        colorbar=True,
        frame_width=800,
        frame_height=400,
        xlabel='Gene',
        ylabel='Cluster',
        invert_yaxis=False,
        show_legend=False,
        yticks=yticks,
        ylim=ylim,
    )

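    # compute_dotplot_data always returns a 'gene_group' column, so check whether
    # it holds any groups (it is all None when markers was a plain list)
    if df['gene_group'].notna().any():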
        # Create a DataFrame for the annotations heatmap
        gene_groups = df[['gene_id', 'gene_group']].drop_duplicates()
        # Map gene groups to integer codes
        gene_groups['group_code'] = gene_groups['gene_group'].factorize()[0]

        # Create a DataFrame for the annotations heatmap
        annotations_df = gene_groups[['gene_id', 'group_code', 'gene_group']].copy()
        annotations_df['Group'] = 'Group'  # Dummy variable for y-axis

        # Ensure 'gene_id' and 'Group' are categorical with the correct order
        annotations_df['gene_id'] = pd.Categorical(annotations_df['gene_id'], categories=gene_ids_order, ordered=True)
        annotations_df['Group'] = pd.Categorical(annotations_df['Group'], categories=['Group'], ordered=True)

        # Create the annotations heatmap
        annotations_heatmap = hv.HeatMap(
            annotations_df,
            kdims=['gene_id', 'Group'],
            vdims=['group_code', 'gene_group']
        )

        # Customize the annotations heatmap
        annotations_heatmap = annotations_heatmap.opts(
            colorbar=False,
            xaxis=None,
            yaxis=None,
            responsive=False,
            frame_width=800,
            cmap='glasbey_hv',
            tools=['hover'],
            toolbar=None,
            height=50,
            show_frame=False,
            show_grid=False,
        )

    else:
        # No gene groups: plot the dots without an annotation bar
        annotations_heatmap = None

    # scipy's dendrogram places leaf i (in display order) at icoord = 5 + 10 * i and
    # internal nodes at the midpoint of their children, so the linear map
    # (icoord - 5) / 10 puts every vertex on the 'cluster_pos' scale.
    # dcoord holds the merge heights, drawn along the x-axis.
    icoord = np.asarray(dendro_data['icoord'])
    dcoord = np.asarray(dendro_data['dcoord'])
    dendro_paths = [np.column_stack([heights, (positions - 5.0) / 10.0])
                    for positions, heights in zip(icoord, dcoord)]

    # Use the same y dimension name as the dot plot so the y-axes are linked
    dendrogram_plot = hv.Path(dendro_paths, kdims=['Distance', 'cluster_pos'])

    # Customize dendrogram
    dendrogram_plot = dendrogram_plot.opts(
        width=200,
        frame_height=400,
        xlabel='',
        ylabel='Cluster',
        invert_yaxis=False,
        xaxis=None,
        yaxis='right',
        show_frame=False,
        fontsize={'labels': '8pt'},
        tools=['hover'],
        yticks=yticks,
        ylim=ylim,
    )

    if annotations_heatmap is not None:
        # Annotation bar above the dot plot; an empty cell keeps the dendrogram aligned
        return hv.Layout([annotations_heatmap, hv.Empty(), points, dendrogram_plot]).cols(2).opts(
            opts.Layout(shared_axes=True)
        )
    return (points + dendrogram_plot).opts(shared_axes=True)

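The dendrogram panel relies on a layout convention of `scipy.cluster.hierarchy.dendrogram`: leaf `i` in display order sits at `icoord = 5 + 10 * i`, internal nodes sit at the midpoint of their children, and `dcoord` holds the merge heights (0 at the leaves). The sketch below checks that convention on a tiny linkage built from four made-up 1-D observations and applies the map `(icoord - 5) / 10`, which places every vertex, leaves and internal nodes alike, on the 0, 1, 2, … row positions used for the dot plot's y-axis.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage

# Four 1-D observations forming two tight pairs (toy values)
points = np.array([[0.0], [0.1], [5.0], [5.2]])
Z = linkage(points, method="average")
d = dendrogram(Z, no_plot=True)

icoord = np.asarray(d["icoord"])  # positions along the leaf axis, one row per link
dcoord = np.asarray(d["dcoord"])  # merge heights, one row per link

# Map scipy's leaf-axis coordinates onto row positions 0 .. n_leaves - 1
row_pos = (icoord - 5.0) / 10.0

# Each link becomes a 4-vertex path of (height, row position) pairs
paths = [np.column_stack([h, y]) for h, y in zip(dcoord, row_pos)]

# Leaf endpoints (height 0) land exactly on the integer rows
leaf_rows = sorted(set(row_pos[dcoord == 0.0]))
print(leaf_rows)
```

Because the map is linear, internal nodes fall halfway between the rows of their children, so no special case is needed for leaves versus internal nodes.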

In [None]:
dp_data = compute_dotplot_data(adata.X, adata.obs['leiden_res_0.50'], adata.var_names, marker_genes, 0.1)
plot = plot_dotplot(dp_data, max_point_size=12)

In [None]:
plot