# Feature Map Workflow

This notebook walks through building a UMAP-style feature map explorer for AnnData objects.

This is a work in progress: only part of the functionality is in place so far.

In [None]:
import scanpy as sc
import anndata as ad
import pooch

import holoviews.operation.datashader as hd
import datashader as ds
import colorcet as cc
import panel as pn
from panel.io import hold
import numpy as np
import holoviews as hv

pn.extension()
hv.extension('bokeh')

## Ingest Data

The data comes from bone marrow mononuclear cells of healthy human donors. The samples were measured with the 10x Multiome Gene Expression and Chromatin Accessibility kit.

In [None]:
EXAMPLE_DATA = pooch.create(
    path=pooch.os_cache("scverse_tutorials"),
    base_url="doi:10.6084/m9.figshare.22716739.v1/",
)
EXAMPLE_DATA.load_registry_from_doi()

In [None]:
%%time

samples = {
    "s1d1": "s1d1_filtered_feature_bc_matrix.h5",
    "s1d3": "s1d3_filtered_feature_bc_matrix.h5",
}
adatas = {}

for sample_id, filename in samples.items():
    path = EXAMPLE_DATA.fetch(filename)
    sample_adata = sc.read_10x_h5(path)
    sample_adata.var_names_make_unique()
    adatas[sample_id] = sample_adata

adata = ad.concat(adatas, label="sample")
adata.obs_names_make_unique()
print(adata.obs["sample"].value_counts())
adata

The data contains ~8,000 cells per sample and 36,601 measured genes.

## Data Preprocessing

### Common Quality Control Metrics

You can pass specific gene populations to {func}`~scanpy.pp.calculate_qc_metrics` to calculate the proportion of counts that comes from each population. Mitochondrial, ribosomal, and hemoglobin genes are identified by the name prefixes listed below.

In [None]:
# mitochondrial genes, "MT-" for human, "Mt-" for mouse
adata.var["mt"] = adata.var_names.str.startswith("MT-")
# ribosomal genes
adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
# hemoglobin genes: "HB" prefix; the character class [^(P)] excludes names starting with "HBP"
adata.var["hb"] = adata.var_names.str.contains("^HB[^(P)]")

In [None]:
sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "ribo", "hb"], inplace=True, log1p=True
)
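A column such as `pct_counts_mt` (and likewise `pct_counts_ribo` and `pct_counts_hb`) holds the percentage of each cell's total counts that comes from the flagged genes. The NumPy sketch below recomputes that quantity on a made-up 3-cell by 4-gene matrix. `pct_counts_in_gene_set` is a hypothetical helper, and returning 0 for a cell with no counts is an assumption of this sketch, not necessarily what scanpy does.

```python
import numpy as np

def pct_counts_in_gene_set(counts, gene_mask):
    """Percentage of each cell's total counts that falls in a gene set.

    counts: (n_cells, n_genes) array of raw counts
    gene_mask: boolean array of length n_genes marking the gene set
    """
    total = counts.sum(axis=1)
    in_set = counts[:, gene_mask].sum(axis=1)
    # Cells with zero total counts get 0% instead of a division by zero
    return np.divide(100 * in_set, total,
                     out=np.zeros_like(total, dtype=float), where=total > 0)

gene_names = np.array(["MT-CO1", "RPL3", "CD3E", "HBB"])
counts = np.array([[5, 10, 85, 0],
                   [0, 0, 0, 0],
                   [20, 20, 40, 20]])
# Same prefix rule as adata.var["mt"] above
mt_mask = np.char.startswith(gene_names, "MT-")
pct_mt = pct_counts_in_gene_set(counts, mt_mask)  # 5%, 0% (empty cell), 20%
```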

### Filter by cells and genes

We filter out cells with fewer than 100 detected genes and genes detected in fewer than 3 cells.

In [None]:
sc.pp.filter_cells(adata, min_genes=100)
sc.pp.filter_genes(adata, min_cells=3)
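Both filters are thresholds on a detection matrix: genes detected per cell and cells in which each gene is detected. A minimal NumPy sketch with a hypothetical `filter_masks` helper, assuming a gene counts as detected when it has at least one count. As with the two scanpy calls above, genes are evaluated only on the cells that survive the cell filter.

```python
import numpy as np

def filter_masks(counts, min_genes=100, min_cells=3):
    """Boolean masks of cells and genes to keep (hypothetical helper)."""
    # A gene is "detected" in a cell if it has at least one count
    detected = counts > 0
    genes_per_cell = detected.sum(axis=1)
    keep_cells = genes_per_cell >= min_genes
    # Count detections only among the cells that passed the first filter
    cells_per_gene = detected[keep_cells].sum(axis=0)
    keep_genes = cells_per_gene >= min_cells
    return keep_cells, keep_genes

toy = np.array([[1, 1, 0, 0],
                [1, 0, 0, 0],
                [1, 1, 1, 0]])
keep_cells, keep_genes = filter_masks(toy, min_genes=2, min_cells=2)
filtered = toy[keep_cells][:, keep_genes]
```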

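### Detect doublets

{func}`~scanpy.pp.scrublet` scores every cell for the likelihood that it is a doublet, running separately for each `sample` via `batch_key`. It adds `doublet_score` and `predicted_doublet` columns to `adata.obs` but does not remove any cells.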

In [None]:
sc.pp.scrublet(adata, batch_key="sample")
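To discard the cells Scrublet flags, keep only the rows whose boolean `predicted_doublet` entry is `False`. The sketch below uses a toy pandas DataFrame with made-up scores as a stand-in for `adata.obs`; on the AnnData object itself the equivalent pattern would be `adata = adata[~adata.obs["predicted_doublet"]].copy()`.

```python
import pandas as pd

# Toy stand-in for adata.obs after sc.pp.scrublet (values are made up)
obs = pd.DataFrame(
    {"doublet_score": [0.05, 0.62, 0.10, 0.48],
     "predicted_doublet": [False, True, False, True]},
    index=["cell1", "cell2", "cell3", "cell4"],
)
# Boolean negation keeps the predicted singlets
singlets = obs[~obs["predicted_doublet"]]
```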

### Count Depth Scaling Normalization

We apply median count depth normalization followed by a log1p transformation (also known as log1PF).


In [None]:
# Saving count data
adata.layers["counts"] = adata.X.copy()

In [None]:
# Normalizing to median total counts
sc.pp.normalize_total(adata)
# Logarithmize the data
sc.pp.log1p(adata)
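The arithmetic behind these two calls: each cell is rescaled so that its total count equals the median total count across cells, and then every value x becomes log(1 + x). The NumPy sketch below reproduces this on a toy matrix; `median_depth_log1p` is a hypothetical name, and it assumes every cell has a nonzero total, which holds after the cell filter above.

```python
import numpy as np

def median_depth_log1p(counts):
    """Scale each cell to the median total count, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=float)
    depth = counts.sum(axis=1, keepdims=True)  # total counts per cell
    target = np.median(depth)
    # Each row now sums to the median depth
    scaled = counts / depth * target
    return np.log1p(scaled)

toy = np.array([[1, 3],
                [2, 2],
                [4, 12]])  # depths 4, 4, 16 -> median 4
normalized = median_depth_log1p(toy)
```

Note that the first and third cells end up identical: they differ only in sequencing depth, which is exactly what this normalization removes.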

### Dimensionality Reduction and Feature Selection

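Select the most informative genes, i.e. the highly variable genes. By default, `sc.tl.pca` in the next step uses only these genes when the `highly_variable` annotation is present.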

In [None]:
sc.pp.highly_variable_genes(adata, n_top_genes=2000, batch_key="sample")
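The core idea of highly variable gene selection is to rank genes by dispersion (variance relative to mean) and keep the top `n_top_genes`. The sketch below is deliberately simplified: scanpy's default `flavor="seurat"` additionally bins genes by mean expression and normalizes dispersions within each bin, and with `batch_key` it selects genes per batch and combines the results. `top_dispersion_genes` is a hypothetical helper.

```python
import numpy as np

def top_dispersion_genes(log_expr, n_top):
    """Indices of the n_top genes with the highest variance/mean ratio (simplified)."""
    mean = log_expr.mean(axis=0)
    var = log_expr.var(axis=0)
    # Genes with zero mean get dispersion 0 instead of dividing by zero
    dispersion = np.divide(var, mean, out=np.zeros_like(mean), where=mean > 0)
    # Sort descending and keep the first n_top indices
    return np.argsort(dispersion)[::-1][:n_top]

toy_log_expr = np.array([[1.0, 0.0, 0.0],
                         [1.0, 4.0, 0.0],
                         [1.0, 0.0, 0.0],
                         [1.0, 4.0, 0.0]])  # constant, variable, and all-zero gene
top = top_dispersion_genes(toy_log_expr, 1)
```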

Reduce the dimensionality of the data by running principal component analysis (PCA), which reveals the main axes of variation and denoises the data.

In [None]:
sc.tl.pca(adata)
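Conceptually, PCA centers each gene and takes the singular value decomposition of the centered matrix; the cell coordinates stored in `adata.obsm["X_pca"]` are the projections onto the leading components. The exact-SVD sketch below illustrates this on random data; `pca_svd` is a hypothetical helper, and scanpy may use a different (for example truncated or randomized) solver on large or sparse inputs.

```python
import numpy as np

def pca_svd(X, n_comps):
    """PCA via SVD of the mean-centered matrix.

    Returns the cell coordinates (scores) and each component's fraction of
    the total variance.
    """
    Xc = X - X.mean(axis=0)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = U[:, :n_comps] * S[:n_comps]
    variance = S**2 / (X.shape[0] - 1)
    variance_ratio = variance / variance.sum()
    return scores, variance_ratio[:n_comps]

rng = np.random.default_rng(0)
# Features with very different spreads, so the ratios fall off clearly
X = rng.normal(size=(50, 5)) * np.array([5.0, 3.0, 1.0, 0.5, 0.1])
scores, ratio = pca_svd(X, 5)
```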

Inspect the contribution of single PCs to the total variance in the data. This gives us information about how many PCs we should consider in order to compute the neighborhood relations of cells.

In [None]:
sc.pl.pca_variance_ratio(adata, n_pcs=50, log=True)
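The plot is read by eye, looking for the point where additional PCs stop contributing much variance. As an alternative rule of thumb (not what this notebook does), you can take the smallest number of PCs whose cumulative variance ratio reaches a threshold; the ratios are stored in `adata.uns["pca"]["variance_ratio"]`, and the result could be passed as `sc.pp.neighbors(adata, n_pcs=...)`. The helper name and the threshold are illustrative assumptions.

```python
import numpy as np

def n_pcs_for_variance(variance_ratio, threshold=0.9):
    """Smallest number of leading PCs whose cumulative variance ratio reaches threshold.

    Capped at the number of available PCs if the threshold is never reached.
    """
    cumulative = np.cumsum(variance_ratio)
    # searchsorted (side="left") gives the first index with cumulative >= threshold
    k = int(np.searchsorted(cumulative, threshold)) + 1
    return min(k, len(variance_ratio))

ratios = np.array([0.5, 0.3, 0.1, 0.1])
```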

## Nearest neighbor graph construction

Let us compute the neighborhood graph of cells using the PCA representation of the data matrix.

In [None]:
sc.pp.neighbors(adata)
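`sc.pp.neighbors` finds each cell's nearest neighbors in the PCA representation (15 by default) and converts them into weighted connectivities. The brute-force sketch below covers only the neighbor search, using squared Euclidean distances; `knn_indices` is a hypothetical helper, excluding each cell from its own neighbor list is a choice of this sketch, and scanpy typically uses an approximate search plus a UMAP-style weighting that is not shown here.

```python
import numpy as np

def knn_indices(X, k):
    """Indices of the k nearest neighbors of each row, excluding the row itself."""
    sq = (X**2).sum(axis=1)
    # Pairwise squared distances: |a|^2 + |b|^2 - 2 a.b
    d2 = sq[:, None] + sq[None, :] - 2 * X @ X.T
    np.fill_diagonal(d2, np.inf)  # a cell is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]

pts = np.array([[0.0], [1.0], [10.0], [11.0]])  # two well-separated pairs
nn = knn_indices(pts, 1)
```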

This graph can then be embedded in two dimensions for visualization with UMAP (McInnes et al., 2018):

In [None]:
sc.tl.umap(adata)

We can now visualize the UMAP embedding colored by `sample`, for example with `sc.pl.umap(adata, color="sample")`.

## Clustering

Use the Leiden graph-clustering method (community detection based on optimizing modularity) {cite}`Traag2019`. Leiden clustering directly clusters the neighborhood graph of cells, which we already computed in the previous section.

In [None]:
%%time

# Note: Using the `igraph` implementation and a fixed number of iterations 
# can be significantly faster, especially for larger datasets
sc.tl.leiden(adata, flavor="igraph", n_iterations=2)
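Modularity, the quality function mentioned above, compares the number of edges inside each community with the number expected if edges were placed at random while preserving node degrees: Q = (1 / 2m) Σ_ij [A_ij − k_i k_j / 2m] δ(c_i, c_j). The NumPy sketch below only scores a given partition (it does not perform Leiden's optimization) and omits the resolution parameter; `modularity` is a hypothetical helper.

```python
import numpy as np

def modularity(A, labels):
    """Newman-Girvan modularity of a partition of an undirected graph.

    A: symmetric (n, n) adjacency matrix; labels: community label per node.
    """
    A = np.asarray(A, dtype=float)
    labels = np.asarray(labels)
    k = A.sum(axis=1)     # node degrees
    two_m = k.sum()       # twice the number of edges
    same = labels[:, None] == labels[None, :]
    expected = np.outer(k, k) / two_m
    return ((A - expected) * same).sum() / two_m

# Two disconnected triangles
A = np.zeros((6, 6))
for i, j in [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5)]:
    A[i, j] = A[j, i] = 1
q_split = modularity(A, [0, 0, 0, 1, 1, 1])
q_merged = modularity(A, [0] * 6)
```

Splitting the two triangles into separate communities scores 0.5, while lumping everything into one community scores 0, which is why a modularity-optimizing method recovers the triangles.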

## Build a Feature Map Explorer

We'll build it step-by-step, starting with the basic components and gradually adding more functionality.

### Creating the Basic Plot Function

Let's start with a function that generates a simple dimensionality reduction plot. It takes the embedding coordinates and two dimension indices and returns a scatter plot:

In [None]:
def create_basic_featuremap(x_data, x_dim, y_dim, xaxis_label, yaxis_label, width=300, height=300):
    """
    Create a basic scatter plot of dimensional reduction data
    
    Parameters:
    - x_data: numpy.ndarray, shape n_obs by n_dimensions
    - x_dim, y_dim: int, indices into x_data's dimension to use as x or y data
    - xaxis_label, yaxis_label: str, labels for the axes
    - width, height: int, dimensions of the plot
    """
    plot = hv.Points(
        (x_data[:, x_dim], x_data[:, y_dim]),
        [xaxis_label, yaxis_label]
    )
    
    plot = plot.opts(
        size=1,
        alpha=0.5,
        tools=['hover'],
        frame_width=width,
        frame_height=height
    )
    
    return plot

pca_data = adata.obsm['X_pca']
basic_plot = create_basic_featuremap(
    pca_data,
    x_dim=0,
    y_dim=1,
    xaxis_label='PC1',
    yaxis_label='PC2',
    width=300,
    height=300
)
basic_plot

### Adding Color Support
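Next, we color the points by a variable, using a discrete palette for categorical annotations (such as cluster labels) and a continuous colormap for numeric values (such as QC metrics).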
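The key decision is whether a variable is categorical or continuous. `create_featuremap_plot` below makes it from the dtype: pandas categoricals, booleans, and string/object arrays are treated as categorical, and numeric arrays as continuous. A standalone sketch of that rule, written as a hypothetical `looks_categorical` helper rather than the function's exact code:

```python
import numpy as np
import pandas as pd

def looks_categorical(values):
    """True for pandas categoricals, booleans, and strings/objects; False for numbers."""
    dtype = getattr(values, "dtype", np.asarray(values).dtype)
    if isinstance(dtype, pd.CategoricalDtype):
        return True
    # kind "O" = object, "U" = unicode string, "S" = byte string
    return dtype == bool or dtype.kind in ("O", "U", "S")

leiden_like = pd.Categorical(["0", "1", "0"])   # e.g. cluster labels
qc_like = np.array([0.1, 2.5, 3.0])             # e.g. a QC metric
```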

### Adding Datashader Support
For large datasets, Datashader renders the points efficiently by aggregating them onto a pixel grid instead of drawing each point individually.
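The aggregation step can be mimicked with NumPy to show the idea: points are binned into a fixed grid, so the rendering cost depends on the grid size rather than on the number of cells. In the app, `hd.rasterize` does this with `ds.count_cat` (per-category counts) for categorical data and `ds.mean` for continuous data. The sketch below uses `np.histogram2d` as a stand-in and is not Datashader itself; `count_raster` and `mean_raster` are hypothetical helpers, and marking empty pixels as NaN is this sketch's convention.

```python
import numpy as np

def count_raster(x, y, width=4, height=4):
    """Number of points per pixel; rows are y bins, columns are x bins."""
    counts, _, _ = np.histogram2d(y, x, bins=(height, width))
    return counts

def mean_raster(x, y, values, width=4, height=4):
    """Mean of `values` per pixel (a rough analogue of a mean aggregation)."""
    sums, _, _ = np.histogram2d(y, x, bins=(height, width), weights=values)
    counts = count_raster(x, y, width, height)
    # Empty pixels have no mean: 0/0 becomes NaN
    with np.errstate(invalid="ignore", divide="ignore"):
        return sums / counts

x = np.array([0.0, 0.0, 1.0])
y = np.array([0.0, 0.0, 1.0])
grid = count_raster(x, y, 2, 2)
means = mean_raster(x, y, np.array([1.0, 3.0, 5.0]), 2, 2)
```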

### Adding Labels Support
Next, we add the option to display each category's label at the median position of its points.
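The label overlay in `create_featuremap_plot` computes, for each category, the median x and y coordinate of its points and builds `(x, y, label)` tuples for `hv.Labels`. The median is less sensitive than the mean to a few stray cells far from a cluster. A standalone NumPy sketch of that computation, with `median_label_positions` as a hypothetical name:

```python
import numpy as np

def median_label_positions(coords, categories):
    """Return (x, y, label) tuples at the per-category median of 2-D coordinates."""
    categories = np.asarray(categories)
    positions = []
    for cat in np.unique(categories):
        mask = categories == cat
        mx, my = np.median(coords[mask], axis=0)
        positions.append((float(mx), float(my), str(cat)))
    return positions

coords = np.array([[0, 0], [2, 2], [10, 10], [12, 12], [100, 100]], dtype=float)
cats = ["a", "a", "b", "b", "b"]
label_positions = median_label_positions(coords, cats)
```

The outlying point at (100, 100) barely moves label "b", which stays at (12, 12).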

### Creating an Interactive App
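Finally, let's put it all together and create an interactive application with Panel. `create_featuremap_plot` combines the color, Datashader, and label support described above, and `featuremap_app` wraps it in widgets: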

In [None]:
def create_featuremap_plot(
    x_data, color_data, x_dim, y_dim, color_var, xaxis_label, yaxis_label,
    width=300, height=300, datashading=False, labels=False,
    cont_cmap='Viridis',
    cat_cmap=cc.b_glasbey_category10):
    """
    Create a scatter plot of a dimensionality reduction, colored by a variable

    Parameters:
    - x_data: numpy.ndarray, shape n_obs by n_dimensions
    - color_data: array-like of length n_obs, color values (categorical or continuous).
    - x_dim, y_dim: int, indices into x_data's dimension axis to use as x or y data.
    - color_var: str, the name to give the coloring dim.
    - xaxis_label, yaxis_label: str, labels for the axes.
    - width, height: int, dimensions of the plot.
    - datashading: bool, whether to apply Datashader.
    - labels: bool, whether to overlay labels at median positions (categorical data only).
    - cont_cmap: colormap used for continuous color data.
    - cat_cmap: list of colors used for categorical color data.
    """
    is_categorical = (
        color_data.dtype.name in ['category', 'categorical', 'bool'] or
        np.issubdtype(color_data.dtype, np.object_) or
        np.issubdtype(color_data.dtype, np.str_)
    )
    
    if is_categorical:
        n_unq_cat = len(np.unique(color_data))
        cmap = cat_cmap[:n_unq_cat]
        colorbar = False
        if labels:
            show_legend = False
        else:
            show_legend = True
    else:
        cmap = cont_cmap
        show_legend = False
        colorbar = True

    plot = hv.Points(
        (x_data[:, x_dim], x_data[:, y_dim], color_data),
        [xaxis_label, yaxis_label], color_var
    )

    plot_opts = dict(
        color=color_var,
        cmap=cmap,
        size=1,
        alpha=0.5,
        colorbar=colorbar,
        padding=0,
        tools=['hover'],
        show_legend=show_legend,
        legend_position='right',
    )

    label_opts = dict(
        text_font_size='8pt',
        text_color='black'
    )

    if datashading:
        if is_categorical:
            aggregator = ds.count_cat(color_var)
            plot = hd.rasterize(plot, aggregator=aggregator)
            plot = hd.dynspread(plot, threshold=0.5)
            plot = plot.opts(cmap=cmap, tools=['hover'])

            if labels:
                unique_categories = np.unique(color_data)
                labels_data = []
                for cat in unique_categories:
                    mask = color_data == cat
                    median_x = np.median(x_data[mask, x_dim])
                    median_y = np.median(x_data[mask, y_dim])
                    labels_data.append((median_x, median_y, str(cat)))
                labels_element = hv.Labels(labels_data, [xaxis_label, yaxis_label], 'Label').opts(**label_opts)
                plot = plot * labels_element
            else:
                # Create a custom legend with dummy scatter plot hack
                unique_categories = np.unique(color_data)
                color_key = dict(zip(unique_categories, cmap[:len(unique_categories)]))
                legend_items = [
                    hv.Points([(0, 0)], label=str(cat)).opts(
                        color=color_key[cat],
                        size=0
                    ) for cat in unique_categories
                ]
                legend = hv.NdOverlay({str(cat): item for cat, item in zip(unique_categories, legend_items)}).opts(
                    show_legend=True,
                    legend_position='right',
                    legend_limit=1000,
                    # at least one column, even with fewer than 8 categories
                    legend_cols=max(1, len(unique_categories) // 8),
                )
                plot = plot * legend

        else:
            aggregator = ds.mean(color_var)
            plot = hd.rasterize(plot, aggregator=aggregator)
            plot = hd.dynspread(plot, threshold=0.5)
            plot = plot.opts(cmap=cmap, colorbar=colorbar)
    else:
        plot = plot.opts(**plot_opts)
        if is_categorical and labels:
            unique_categories = np.unique(color_data)
            labels_data = []
            for cat in unique_categories:
                mask = color_data == cat
                median_x = np.median(x_data[mask, x_dim])
                median_y = np.median(x_data[mask, y_dim])
                labels_data.append((median_x, median_y, str(cat)))
            labels_element = (hv.Labels(labels_data, [xaxis_label, yaxis_label], 'Label')
                              .opts(**label_opts)
                             )
            plot = plot * labels_element

    return plot.opts(
        title=f"{color_var}",
        tools=['hover'],
        show_legend=show_legend,
        frame_width=width,
        frame_height=height
    )



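def layout_dimreduction_plots(adata, color, dimensions, width=200, height=200, dim_reduction='X_pca', labels=False):
    """Lay out one datashaded feature map per entry in `color`.

    `color` holds obs column names or gene names; `dimensions` holds one
    (x_dim, y_dim) index pair per color entry.
    """
    dr_key = dim_reduction
    dr_label = dr_key.split('_')[1].upper()
    plots = []
    x_data = adata.obsm[dr_key]
    for color_var, (x_dim, y_dim) in zip(color, dimensions):
        if color_var in adata.obs.columns:
            # color by an obs annotation (e.g. cluster labels or QC metrics)
            color_data = adata.obs[color_var].values
        else:
            # color by gene expression; adata.X may be sparse or dense
            expr = adata[:, color_var].X
            expr = expr.toarray() if hasattr(expr, 'toarray') else np.asarray(expr)
            color_data = expr.ravel()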
        plot = create_featuremap_plot(
            x_data, color_data, x_dim, y_dim, color_var,
            f'{dr_label}{x_dim + 1}', f'{dr_label}{y_dim + 1}',
            width=width,
            height=height,
            datashading=True,
            labels=labels,
        )
        plots.append(plot)
    return hv.Layout(plots).opts(shared_axes=False, axiswise=True)

def featuremap_app(
    adata,
    dim_reduction=None,
    color_by=None,
    datashade=True,
    show_widgets=True,
    width=200,
    height=200,
    labels=False,
):
    
    color_options = list(adata.obs.columns)
    default_color = color_by or color_options[0]

    # Map dimension reduction methods to their labels
    dr_options = {key: key.split('_')[1].upper() for key in adata.obsm.keys()}
    default_dr = dim_reduction or list(dr_options.keys())[0]

    x_data = adata.obsm[default_dr]
    num_dims = x_data.shape[1]

    dr_label = dr_options[default_dr]
    dim_labels = [f"{dr_label}{i+1}" for i in range(num_dims)]
    dim_mapping = {label: index for index, label in enumerate(dim_labels)}

    dr_select = pn.widgets.Select(
        name='Dim Reduction', options=list(dr_options.keys()), value=default_dr
    )
    xaxis = pn.widgets.Select(name='X-axis', options=dim_labels, value=dim_labels[0])
    yaxis = pn.widgets.Select(name='Y-axis', options=dim_labels, value=dim_labels[1])
    color = pn.widgets.Select(name='Color By', options=color_options, value=default_color)
    datashading_switch = pn.widgets.Checkbox(name='Enable Datashader', value=datashade)

    def update_plot(dr_select_value, xaxis_value, yaxis_value, color_value, datashading_value):
        x_data = adata.obsm[dr_select_value]
        dr_label = dr_options[dr_select_value]
        num_dims = x_data.shape[1]
        dim_labels = [f"{dr_label}{i+1}" for i in range(num_dims)]
        dim_mapping = {label: index for index, label in enumerate(dim_labels)}

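        # While the axis widgets are being switched to a new reduction, they can
        # briefly hold labels from the previous one; fall back to the first two
        # dimensions instead of raising a KeyError
        x_dim = dim_mapping.get(xaxis_value, 0)
        y_dim = dim_mapping.get(yaxis_value, 1)
        color_data = adata.obs[color_value].values

        return create_featuremap_plot(
            x_data,
            color_data,
            x_dim,
            y_dim,
            color_value,
            dim_labels[x_dim],
            dim_labels[y_dim],
            width=width,
            height=height,
            datashading=datashading_value,
            labels=labels,
        )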

    plot_pane = pn.bind(
        update_plot,
        dr_select_value=dr_select,
        xaxis_value=xaxis,
        yaxis_value=yaxis,
        color_value=color,
        datashading_value=datashading_switch,
    )

    @hold()
    def update_axis_options(event):
        x_data = adata.obsm[event.new]
        num_dims = x_data.shape[1]
        dr_label = dr_options[event.new]
        new_dim_labels = [f"{dr_label}{i+1}" for i in range(num_dims)]
        xaxis.options = new_dim_labels
        yaxis.options = new_dim_labels
        xaxis.value = new_dim_labels[0]
        yaxis.value = new_dim_labels[1]

    dr_select.param.watch(update_axis_options, 'value')

    # Ensure x-axis and y-axis selections are different
    def enforce_different_axes(event):
        if xaxis.value == yaxis.value:
            available_options = [opt for opt in yaxis.options if opt != xaxis.value]
            if available_options:
                yaxis.value = available_options[0]

    xaxis.param.watch(enforce_different_axes, 'value')
    yaxis.param.watch(enforce_different_axes, 'value')

    widgets = pn.WidgetBox(dr_select, xaxis, yaxis, color, datashading_switch)
    if show_widgets:
        app = pn.Row(widgets, plot_pane)
    else:
        app = pn.Row(plot_pane)
    return app

In [None]:
featuremap_app(adata, dim_reduction='X_umap', color_by='leiden')

In [None]:
# layout = layout_dimreduction_plots(
#     adata=adata,
#     color=["leiden", "log1p_total_counts", "pct_counts_mt", "log1p_n_genes_by_counts"],
#     dimensions=[(0, 1)]*4,
#     dim_reduction='X_umap',
# )
# layout.cols(2)