# Squidpy spatial scatter with HoloViz and Bokeh

This example shows how to reproduce `squidpy.pl.spatial_scatter`-style plots with HoloViews
and Bokeh, plotting annotations and features stored in an `anndata.AnnData` object.

This kind of plot is useful when both the spot coordinates and an underlying tissue image are available.

1. Extract the image from `adata.uns["spatial"][library_id]["images"]["hires"]` (or "lowres").
2. Convert the image to a format that `holoviews.RGB` can consume.
3. Extract the point coordinates from `adata.obsm["spatial"]`.
4. Overlay the image with `holoviews.Points` of the spatial coordinates.
5. Optionally overlay the spatial connectivity graph as line segments.
6. Wrap the plots in an interactive Panel app.

In [None]:
import squidpy as sq
from scipy.sparse import coo_matrix

import holoviews as hv
import panel as pn
import numpy as np
import pandas as pd
import param

hv.extension("bokeh")

In [None]:
# Load squidpy's example Visium H&E dataset as an AnnData object. The cells below
# read the tissue image from .uns["spatial"] and spot coordinates from .obsm["spatial"].
adata = sq.datasets.visium_hne_adata()

## Extract the image and coordinates data

In [None]:
library_id = "V1_Adult_Mouse_Brain"

# All per-library Visium metadata (images, scale factors) lives under this key.
spatial_meta = adata.uns["spatial"][library_id]

# High-resolution tissue image; the 3-way unpack requires a (height, width, channels) array.
img_data = spatial_meta["images"]["hires"]
height, width, _ = img_data.shape

# Spot centres as stored on the AnnData, one row per observation, columns (x, y).
coords_fullres = adata.obsm["spatial"]

# Bring the spot centres into hires-image pixel space via the hires scale factor.
scalef = spatial_meta["scalefactors"]["tissue_hires_scalef"]
coords_hires = coords_fullres * scalef

## Transform the image into hv.RGB

In [None]:
# hv.RGB is given 8-bit channels here. Float images are assumed to hold
# intensities in [0, 1] (NOTE(review): typical for scanpy/squidpy, confirm for
# your data) and are scaled up; integer images are assumed to already be 0-255.
if np.issubdtype(img_data.dtype, np.floating):
    # Round rather than truncate, and clip so values marginally outside
    # [0, 1] cannot wrap around when cast to uint8.
    img_data_uint8 = np.clip(np.rint(img_data * 255), 0, 255).astype(np.uint8)
else:
    # Multiplying an integer image by 255 would overflow; clip and cast only.
    img_data_uint8 = np.clip(img_data, 0, 255).astype(np.uint8)

# Only the first three channels are used (any alpha channel is ignored).
red   = img_data_uint8[:, :, 0]
green = img_data_uint8[:, :, 1]
blue  = img_data_uint8[:, :, 2]

# Pixel column/row i is placed at coordinate i, i.e. the same hires pixel
# units as coords_hires.
xvals = np.arange(width)
yvals = np.arange(height)

hv_rgb = hv.RGB(
    (xvals, yvals, red, green, blue),
    kdims=["x", "y"],
    vdims=["R", "G", "B"]
).opts(data_aspect=1)
hv_rgb

## Get gene expression and overlay a continuous `hv.Points` scatter plot

In [None]:
gene_var = 'Sox8'

# Pull the single gene column; densify only when AnnData returns a sparse matrix.
expr_matrix = adata[:, gene_var].X
if hasattr(expr_matrix, "toarray"):
    expr_matrix = expr_matrix.toarray()
gene_expr = expr_matrix.flatten()

# One row per spot: hires pixel position plus the expression value.
df_expr = pd.DataFrame(
    {"x": coords_hires[:, 0], "y": coords_hires[:, 1], gene_var: gene_expr}
)

# Continuous colouring. The y axis is inverted because image rows grow downwards.
continuous_opts = dict(
    color=gene_var,
    cmap="viridis",
    colorbar=True,
    tools=["hover"],
    size=2,
    invert_yaxis=True,
    line_alpha=0,
    title=gene_var,
    xaxis='bare',
    yaxis='bare',
    scalebar=True,
    scalebar_opts={'bar_length': .1},
    scalebar_unit=('µm', 'm'),
)
points_expr = hv.Points(df_expr, kdims=["x", "y"], vdims=[gene_var]).opts(**continuous_opts)

overlay_cont = hv_rgb.opts(data_aspect=1) * points_expr
overlay_cont

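TODO: the axes are in hires-image pixels, so the scale bar labelled via `scalebar_unit=('µm', 'm')` is not physically calibrated. Either implement a scale multiplier on the Bokeh side, or rescale `xvals`, `yvals` and the spot coordinates to physical units so that the scale bar can map them directly.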

## Get cluster labels and overlay a categorical `hv.Points` scatter plot

In [None]:
cat_var = 'cluster'
# Labels as plain strings, one per observation.
cat_data = adata.obs[cat_var].astype(str).tolist()

# Rows sorted by label (string order), presumably for a deterministic
# draw/legend order.
df_clu = (
    pd.DataFrame({"x": coords_hires[:, 0], "y": coords_hires[:, 1], cat_var: cat_data})
    .sort_values(cat_var)
)

# Categorical colouring with a two-column legend on the right.
categorical_opts = dict(
    color=cat_var,
    cmap="glasbey_light",
    tools=["hover"],
    size=2,
    legend_position='right',
    legend_cols=2,
    invert_yaxis=True,
    line_alpha=0,
    scalebar=True,
    scalebar_opts={'bar_length': .1},
    scalebar_unit=('µm', 'm'),
    title=cat_var,
)
points_clu = hv.Points(df_clu, kdims=["x", "y"], vdims=[cat_var]).opts(**categorical_opts)

overlay_cat = hv_rgb.opts(data_aspect=1) * points_clu
overlay_cat

## Spatial Connectivities

In [None]:
# Compute the spatial neighbour graph for the spots; the next cell reads it back
# from adata.obsp["spatial_connectivities"].
sq.gr.spatial_neighbors(adata)

In [None]:
# Build line segments (edges) from the spatial connectivity graph for hv.Segments.
# Convert to COO to get explicit (row, col) index arrays.
adj = adata.obsp["spatial_connectivities"]
if not isinstance(adj, coo_matrix):
    adj = coo_matrix(adj)  # convert to COO for .row, .col

# The graph is typically symmetric, so each neighbour pair appears as both (i, j)
# and (j, i). Drop explicitly stored zeros and self-loops (zero-length segments),
# canonicalise each pair as (min, max) and keep unique pairs, so each edge is drawn once.
pairs = np.column_stack([adj.row, adj.col])[adj.data != 0]
pairs = pairs[pairs[:, 0] != pairs[:, 1]]
pairs = np.unique(np.sort(pairs, axis=1), axis=0)
rows, cols = pairs[:, 0], pairs[:, 1]

# One segment per edge: (x0, y0) -> (x1, y1), in hires pixel coordinates.
df_edges = pd.DataFrame({
    "x0": coords_hires[rows, 0],
    "y0": coords_hires[rows, 1],
    "x1": coords_hires[cols, 0],
    "y1": coords_hires[cols, 1],
})

# Plain white segments, not coloured by any attribute. The y-axis inversion comes
# from the scatter points (invert_yaxis=True) in the overlay below.
edges = hv.Segments(df_edges, kdims=["x0", "y0", "x1", "y1"]).opts(
    color="white", line_width=0.5, line_alpha=0.9
)
# clone=True so the zoomed-in size/limits do not leak back into points_clu / overlay_cat.
(overlay_cat * edges).opts(
    hv.opts.Points(size=10, xlim=(200, 600), ylim=(200, 600)), clone=True
)

## Panel app

In [None]:
def create_points_overlay(
    adata,
    column: str,
    coords: np.ndarray,
    hv_rgb: hv.RGB,
    edges: hv.Segments | None = None,
    invert_y: bool = True,
    point_size: int = 3,
):
    """
    Build a HoloViews overlay of the tissue image, optional edges and coloured spots.

    ``column`` is looked up first in ``adata.var_names`` (values taken from
    ``adata.X``) and then in ``adata.obs.columns``.

    Gene columns are always plotted as continuous values, whatever the dtype of
    ``adata.X``, so integer count matrices also get a colorbar. Obs columns with
    a float dtype are continuous; any other obs dtype is treated as categorical.

    Parameters
    ----------
    adata : AnnData
        Source of gene expression / obs annotations.
    column : str
        Gene name or obs column to colour the points by.
    coords : np.ndarray
        Point positions; column 0 is used as x and column 1 as y, one row per observation.
    hv_rgb : hv.RGB
        Background image. It is not modified: options are applied to a clone.
    edges : hv.Segments or None
        Optional segments drawn between the image and the points.
    invert_y : bool
        Whether to invert the y axis of the final overlay.
    point_size : int
        Marker size of the points.

    Returns
    -------
    hv.Overlay
        ``hv_rgb`` (with scale-bar options) [* edges] * points.

    Raises
    ------
    ValueError
        If ``column`` is neither a gene name nor an obs column.
    """
    # gene or obs?
    if column in adata.var_names:
        values = adata[:, column].X
        if hasattr(values, "toarray"):
            values = values.toarray()
        # np.asarray + ravel gives a true 1-D array even for np.matrix input
        # (np.matrix.flatten() would stay 2-D).
        values = np.asarray(values).ravel()
        # Expression values are quantities: continuous regardless of dtype.
        is_continuous = True
    elif column in adata.obs.columns:
        values = adata.obs[column].values
        # simplification: float -> continuous; anything else -> categorical
        is_continuous = pd.api.types.is_float_dtype(values)
    else:
        raise ValueError(f"Column '{column}' not found in adata.var_names or adata.obs.")

    df_plot = pd.DataFrame({"x": coords[:, 0], "y": coords[:, 1], column: values})

    points = hv.Points(df_plot, kdims=["x", "y"], vdims=[column])

    # Options shared by both colouring modes.
    common_opts = dict(
        color=column,
        tools=["hover"],
        size=point_size,
        line_alpha=0,
        title=f"{column}",
    )
    if is_continuous:
        points = points.opts(cmap="viridis", colorbar=True, **common_opts)
    else:
        points = points.opts(
            cmap="glasbey_light",
            legend_position='right',
            legend_cols=2,
            **common_opts,
        )

    # clone=True: do not mutate the caller's (shared) hv_rgb in place.
    base = hv_rgb.opts(
        data_aspect=1,
        scalebar=True,
        scalebar_opts={'bar_length': .1},
        scalebar_unit=('µm', 'm'),
        clone=True,
    )
    if edges is not None:
        base = base * edges
    return (base * points).opts(invert_yaxis=invert_y)


In [None]:
class VisiumViewer(pn.viewable.Viewer):
    """
    Interactive Panel viewer for a Visium dataset.

    Lets the user pick a gene or ``adata.obs`` column to colour the spots by and
    toggle the spatial edges. The overlay from ``create_points_overlay`` is
    re-rendered whenever either parameter changes.
    """

    # The choices are filled in per instance in __init__ (genes + obs columns).
    color_col = param.Selector(
        doc="Select a gene or obs column to color by."
    )
    show_edges = param.Boolean(default=False, doc="Toggle spatial edges")

    def __init__(self, adata, coords_hires, hv_rgb, edges, **params):
        """
        Parameters
        ----------
        adata : AnnData
            Its gene names and obs columns become the colour choices.
        coords_hires : np.ndarray
            Spot coordinates in the same pixel space as ``hv_rgb``.
        hv_rgb : hv.RGB
            Background tissue image.
        edges : hv.Segments
            Segments drawn when ``show_edges`` is enabled.
        **params
            Forwarded to param, e.g. ``color_col`` or ``show_edges``.
        """
        super().__init__(**params)
        self.adata = adata
        self.coords = coords_hires
        self.hv_rgb = hv_rgb
        self.edges = edges

        # Choices: all gene names followed by all obs columns.
        all_cols = list(self.adata.var_names) + list(self.adata.obs.columns)
        self.param.color_col.objects = all_cols

        # Without an explicit color_col the Selector stays None. create_points_overlay
        # would then raise ValueError on first render, so pick a default: the
        # "cluster" annotation when present, otherwise the first choice.
        if self.color_col is None and all_cols:
            self.color_col = "cluster" if "cluster" in all_cols else all_cols[0]

    @param.depends("color_col", "show_edges")
    def view(self):
        """Return the overlay for the current colour column and edge toggle."""
        use_edges = self.edges if self.show_edges else None
        return create_points_overlay(
            self.adata,
            self.color_col,
            self.coords,
            self.hv_rgb,
            edges=use_edges,
            point_size=3,
        )

    def __panel__(self):
        """Lay out a title, the parameter widgets and the reactive plot."""
        controls = pn.Param(
            self.param,
            widgets={
                "color_col": {"type": pn.widgets.Select, "name": "Color Column"},
                "show_edges": {"type": pn.widgets.Checkbox, "name": "Show Edges?"},
            },
            show_name=False,
        )
        return pn.Column(
            pn.pane.Markdown("## Visium HoloViz Viewer"),
            pn.Row(controls, self.view),
        )


In [None]:
# Constructor arguments shared by both demo viewers.
demo_args = (adata, coords_hires, hv_rgb, edges)

# One viewer starts on a gene (Sox8), the other on the cluster annotation;
# both are also marked servable for `panel serve`.
cont_app_demo = VisiumViewer(*demo_args, color_col='Sox8').servable()
cat_app_demo = VisiumViewer(*demo_args, color_col='cluster').servable()

pn.Column(cont_app_demo, cat_app_demo)