In [None]:
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "anywidget",
#     "dimbridge",
#     "jupyter-scatter>=0.21.0",
#     "llvmlite>=0.44.0",
#     "matplotlib",
#     "numpy",
#     "pandas",
#     "pyarrow",
#     "scipy",
#     "torch",
#     "umap-learn>=0.5.7",
# ]
#
# [tool.uv]
# exclude-newer = "2025-02-17T20:44:33.141251-05:00"
# ///

# Jupyter-Scatter with DimBridge

> To run this notebook, install [juv](https://github.com/manzt/juv) and call `juv run dimbridge.ipynb`.

In this notebook, I created a custom UI for [Jupyter-Scatter](https://jupyter-scatter.dev) to call [DimBridge](https://arxiv.org/abs/2404.07386). DimBridge is a visual analytics method that helps users interpret patterns in embedding visualizations. It uses predicate logic to build an interpretable "bridge" between the low-dimensional projection space and the original high-dimensional data space. For example, if you select a cluster of interest in the embedding visualization, DimBridge tries to explain that cluster in terms of the most important dimensions of the original high-dimensional data. To achieve this, DimBridge uses a novel predicate regression algorithm that finds optimal predicates (logical expressions over data dimensions and value ranges) to explain patterns in dimensionality-reduced data by approximating high-dimensional bounding cuboids via a differentiable proxy function.
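To make the "bounding cuboid" idea a bit more concrete, here is a minimal sketch of what a predicate and a differentiable stand-in for it could look like. A predicate over `d` dimensions is treated as an axis-aligned box `[lower, upper]`. The hard membership test is a step function and therefore not differentiable, so the sketch swaps in a product of sigmoids whose steepness is set by a hypothetical `sharpness` parameter. This is an assumption for illustration only, not DimBridge's actual proxy function or implementation.

```python
import numpy as np


def hard_membership(x: np.ndarray, lower: np.ndarray, upper: np.ndarray) -> np.ndarray:
    """1.0 where a row of x lies inside the box [lower, upper] in every dimension, else 0.0."""
    inside = (x >= lower) & (x <= upper)
    return inside.all(axis=1).astype(float)


def soft_membership(
    x: np.ndarray, lower: np.ndarray, upper: np.ndarray, sharpness: float = 10.0
) -> np.ndarray:
    """Smooth, differentiable approximation of hard_membership (hypothetical proxy)."""

    def sigmoid(z: np.ndarray) -> np.ndarray:
        return 1.0 / (1.0 + np.exp(-z))

    # Per dimension: close to 1 inside [lower, upper], close to 0 outside
    per_dim = sigmoid(sharpness * (x - lower)) * sigmoid(sharpness * (upper - x))
    # A point is "in the box" only if it is inside along every dimension
    return per_dim.prod(axis=1)


x = np.array([[0.5, 0.5], [2.0, 0.5]])
lower, upper = np.array([0.0, 0.0]), np.array([1.0, 1.0])
print(hard_membership(x, lower, upper))
print(soft_membership(x, lower, upper, sharpness=50.0))
```

As `sharpness` grows, the soft membership approaches the hard one. A predicate regression would then tune `lower` and `upper` by gradient descent on some objective over these soft memberships; the exact objective DimBridge uses is described in the paper and is not reproduced here.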

For instance, if you see a cluster of cells in a UMAP or t-SNE embedding, DimBridge can help you find out which protein or gene expressions explain why the selected cluster formed or how it differs from the rest of the data.

For this demo, we're going to load a single-cell surface protein dataset from [Mair et al. (2022)](https://www.nature.com/articles/s41586-022-04718-w) that was clustered with [Greene et al.'s (2021) FAUST method](https://www.cell.com/patterns/fulltext/S2666-3899(21)00234-8) to derive cell populations and cell types, and embedded with t-SNE for visualization.

In [None]:
!mkdir -p data
!curl -L -C - -o data/mair-2022-tissue-138-ozette-tooltip.pq https://storage.googleapis.com/flekschas/jupyter-scatter-tutorial/mair-2022-tissue-138-ozette-tooltip.pq

In [None]:
import pandas as pd

df = pd.read_parquet('data/mair-2022-tissue-138-ozette-tooltip.pq')
df.head(2)

## Additional UI Widgets

To visualize the results of DimBridge, we need a few additional widgets:
1. Histogram: a widget to visualize the expression distribution of a surface protein
2. Label: a widget to display a selection of points
3. Divider: a widget to visually separate groups of histograms

In the following, we're going to create these widgets, all of which are implemented with Trevor Manz's fantastic [AnyWidget](https://anywidget.dev) library.

### Histogram Widget

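We're going to start with the histogram widget. Feel free to browse the code, but in a nutshell, this widget draws a bar chart of precomputed bin counts onto a canvas. Optionally, it also draws a second histogram in the background, a title, and a bar underneath that marks an interval (which we'll use to show a predicate's value range).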

In [None]:
from anywidget import AnyWidget
from traitlets import Float, Int, List, Unicode


class Histogram(AnyWidget):
    _esm = """
    function getCssVariable(variableName) {
      return window.getComputedStyle(document.body).getPropertyValue(variableName);
    }

    function getFontSizeInPixels(variableName) {
      // Custom property values can carry leading whitespace, so trim first
      const variable = getCssVariable(variableName).trim();
      if (variable.endsWith('px')) {
        return parseFloat(variable);
      }
      // `endsWith('em')` also matches `rem`. parseFloat keeps fractional values
      // like `0.83333em`, and both units are approximated relative to the
      // body's font size.
      if (variable.endsWith('em')) {
        const relativeNumber = parseFloat(variable);
        const baseFontSizeInPixels = parseFloat(window.getComputedStyle(document.body).fontSize);
        return baseFontSizeInPixels * relativeNumber;
      }
      return undefined;
    }
    
    function render({ model, el }) {
      const canvas = document.createElement("canvas");
      canvas.classList.add('jupyter-scatter-histogram');
      
      const update = async () => {
        const histogram = model.get('histogram');
        const max = model.get('max');
        const height = model.get('height');
        const color = model.get('color');
        
        const title = model.get('title');
        const titleColor = model.get('title_color');
        
        const backgroundHistogram = model.get('background_histogram');
        const backgroundColor = model.get('background_color');
        
        const interval = model.get('interval');
        const intervalHeight = model.get('interval_height');
        const intervalPadding = model.get('interval_padding');

        const histogramHeight = interval
          ? height - intervalHeight - intervalPadding
          : height;

        canvas.style.height = `${height}px`;

        const ctx = canvas.getContext('2d');
        ctx.clearRect(0, 0, canvas.width, canvas.height);

        if (!histogram?.length) {
          return;
        }

        const numBins = histogram.length;
        const maxValue = !max
          ? histogram.reduce((max, v, i) => Math.max(max, v, backgroundHistogram?.[i] || 0), 0)
          : max;
        const { width } = canvas.getBoundingClientRect();
        const binWidth = (width - (numBins - 1)) / numBins;

        const dpr = window.devicePixelRatio;
        canvas.width = width * dpr;
        canvas.height = height * dpr;

        const draw = (values, color) => {
          ctx.fillStyle = color;
        
          let x = 0
          for (const value of values) {
            const binHeight = value / maxValue * histogramHeight;
            ctx.fillRect(
              x * dpr,
              (histogramHeight - binHeight) * dpr,
              binWidth * dpr,
              binHeight * dpr
            );
            x += binWidth + 1;
          }
        }

        if (backgroundHistogram) {
          draw(backgroundHistogram, backgroundColor);
        }

        draw(histogram, color);

        if (title) {
          const fontFamily = getCssVariable('--jp-ui-font-family') || 'sans-serif';
          const fontSize = getFontSizeInPixels('--jp-ui-font-size0') || 10;
          ctx.font = `${fontSize * dpr}px ${fontFamily}`;
          ctx.textAlign = 'center';
          ctx.textBaseline = 'hanging';
          ctx.strokeStyle = getCssVariable('--jp-layout-color0');
          ctx.lineWidth = 4;
          ctx.strokeText(title, canvas.width / 2, 0);
          ctx.fillStyle = titleColor || color;
          ctx.fillText(title, canvas.width / 2, 0);
        }

        if (interval) {
          ctx.fillStyle = backgroundColor;
          ctx.fillRect(
            0,
            (histogramHeight + intervalPadding) * dpr,
            canvas.width,
            intervalHeight * dpr
          );
          ctx.fillStyle = color;
          ctx.fillRect(
            interval[0] * canvas.width,
            (histogramHeight + intervalPadding) * dpr,
            (interval[1] - interval[0]) * canvas.width,
            intervalHeight * dpr
          );
        }
      }
      
      model.on("change", update);

      el.appendChild(canvas);

      const resizeObserver = new ResizeObserver(() => {
        window.requestAnimationFrame(() => {
          update();
        });
      });
        
      resizeObserver.observe(canvas);

      window.requestAnimationFrame(() => {
        update();
      });

      return () => {
        model.off("change", update);
        resizeObserver.disconnect();
      }
    }
    export default { render };
    """

    _css = """
    .jupyter-scatter-histogram {
      width: 100%;
    }
    """

    histogram = List(Float(), allow_none=True).tag(sync=True)
    max = Float(None, allow_none=True).tag(sync=True)
    height = Int(80).tag(sync=True)
    color = Unicode('black').tag(sync=True)

    title = Unicode('').tag(sync=True)
    title_color = Unicode('').tag(sync=True)

    background_histogram = List(Float(), default_value=None, allow_none=True).tag(
        sync=True
    )
    background_color = Unicode('gray').tag(sync=True)

    interval = List(Float(), default_value=None, allow_none=True).tag(sync=True)
    interval_height = Int(4).tag(sync=True)
    interval_padding = Int(2).tag(sync=True)

### Label Widget

The next widget represents a selection of points as a label. Nothing fancy here. We use it to (a) tell you which points you have selected, (b) sit next to a trash button that lets you delete the selection, and (c) zoom to the selected points when you click on (i.e., focus) the label, and zoom back out when the label loses focus.

In [None]:
from anywidget import AnyWidget
from traitlets import Bool, Dict, Unicode


class Label(AnyWidget):
    _esm = """
    function render({ model, el }) {
      const label = document.createElement("div");
      label.classList.add(
        'jupyter-widgets',
        'jupyter-scatter-label'
      );
      label.tabIndex = 0;
      
      const update = () => {
        label.textContent = model.get('name');

        for (const [key, value] of Object.entries(model.get('style'))) {
          label.style[key] = value;
        }
      }
      
      model.on('change:name', update);
      model.on('change:style', update);

      update();

      const createFocusChanger = (value) => () => {
        model.set('focus', value);
        model.save_changes();
      }

      const focusHandler = createFocusChanger(true);
      const blurHandler = createFocusChanger(false);

      label.addEventListener('focus', focusHandler);
      label.addEventListener('blur', blurHandler);

      el.appendChild(label);

      const updateFocus = () => {
        if (model.get('focus')) {
          label.focus();
        }
      }
      
      model.on('change:focus', updateFocus);

      window.requestAnimationFrame(() => {
        updateFocus();
      });

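      return () => {
        model.off('change:name', update);
        model.off('change:style', update);
        model.off('change:focus', updateFocus);
        label.removeEventListener('focus', focusHandler);
        label.removeEventListener('blur', blurHandler);
      }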
    }
    export default { render };
    """

    _css = """
    .jupyter-scatter-label {
      display: flex;
      align-items: center;
      width: 100%;
      height: var(--jp-widgets-inline-height);
      padding: var(--jp-widgets-input-padding) calc(var(--jp-widgets-input-padding)* 2);
      border-top-left-radius: var(--jp-border-radius);
      border-top-right-radius: 0;
      border-bottom-left-radius: var(--jp-border-radius);
      border-bottom-right-radius: 0;
    }
    .jupyter-scatter-label:focus {
      font-weight: bold;
      outline: 1px solid var(--jp-widgets-input-focus-border-color);
      outline-offset: 1px;
    }
    """

    name = Unicode('').tag(sync=True)
    style = Dict({}).tag(sync=True)
    focus = Bool(False).tag(sync=True)

### Divider Widget

And finally, the technically most challenging (read: most boring) widget: a horizontal dividing line. Please don't waste time looking at the code, as there's nothing to see here.

In [None]:
from anywidget import AnyWidget
from traitlets import Dict


class Div(AnyWidget):
    _esm = """
    function render({ model, el }) {
      const div = document.createElement("div");
      div.classList.add(
        'jupyter-widgets',
        'jupyter-scatter-div'
      );
      
      const update = () => {
        for (const [key, value] of Object.entries(model.get('style'))) {
          div.style[key] = value;
        }
      }
      
      model.on('change', update);

      update();

      el.appendChild(div);
    }
    export default { render };
    """

    style = Dict({}).tag(sync=True)

### Helper Functions

Next, we're going to define a bunch of helper functions for splitting a brush selection into `N` consecutive polygons. E.g., say you brush from left to right as follows:

```
|---------------------------->
```

If you choose to subdivide this selection into `5` subselections, the result will be something like this:

```
|---->|---->|---->|---->|---->
```

To get a feel for it, switch Jupyter-Scatter's lasso type to `brush`, select something, and then check the box to `subdivide` the selection.
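The core idea can be sketched in a few lines of NumPy: resample each side of the brush polygon at equidistant positions along its arc length, then stitch corresponding pieces of the two sides together. This is a simplified sketch, not the helpers below. `resample_equidistant` and `subdivide_brush` are hypothetical names, the sub-polygons are reduced to quadrilaterals (the original vertices between split points are dropped), and it assumes, like `add_subdivided_selections` further down, that the first half of the polygon traces one side of the brush and the second half traces the other side in reverse. It also assumes there are no duplicate consecutive vertices, since `np.interp` expects increasing arc lengths.

```python
import numpy as np


def resample_equidistant(vertices: np.ndarray, n_points: int) -> np.ndarray:
    """Return n_points positions spaced evenly by arc length along a polyline."""
    segment_lengths = np.linalg.norm(np.diff(vertices, axis=0), axis=1)
    cumulative = np.concatenate(([0.0], np.cumsum(segment_lengths)))
    targets = np.linspace(0.0, cumulative[-1], n_points)
    # Interpolate x and y independently as functions of arc length
    return np.column_stack(
        (
            np.interp(targets, cumulative, vertices[:, 0]),
            np.interp(targets, cumulative, vertices[:, 1]),
        )
    )


def subdivide_brush(polygon: np.ndarray, n_parts: int) -> list[np.ndarray]:
    """Cut a brush polygon into n_parts consecutive quadrilaterals."""
    mid = len(polygon) // 2
    # Orient both sides of the brush in the same direction
    side_a = resample_equidistant(polygon[:mid], n_parts + 1)
    side_b = resample_equidistant(polygon[mid:][::-1], n_parts + 1)
    # Part i spans the i-th and (i+1)-th split points on both sides
    return [
        np.array([side_a[i], side_a[i + 1], side_b[i + 1], side_b[i]])
        for i in range(n_parts)
    ]


# A straight, 10-unit-long and 1-unit-wide brush stroke
brush = np.array([[0, 0], [5, 0], [10, 0], [10, 1], [5, 1], [0, 1]], dtype=float)
for part in subdivide_brush(brush, 5):
    print(part.tolist())
```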

In [None]:
import numpy as np

from matplotlib.path import Path


def find_equidistant_vertices(vertices: np.ndarray, n_points: int) -> np.ndarray:
    """
    Find N equidistant points along a line defined by vertices.

    Args:
        vertices: numpy array of shape (M, 2) containing the line vertices
        n_points: number of equidistant points to find

    Returns:
        numpy array of shape (n_points, 2) containing the equidistant points
    """
    # Calculate cumulative distances along the line
    segment_vectors = np.diff(vertices, axis=0)
    segment_lengths = np.linalg.norm(segment_vectors, axis=1)
    cumulative_lengths = np.concatenate(([0], np.cumsum(segment_lengths)))
    total_length = cumulative_lengths[-1]

    # Calculate target distances for equidistant points
    target_distances = np.linspace(0, total_length, n_points)

    # Initialize output array
    equidistant_points = np.zeros((n_points, 2))

    # For each target distance, find the corresponding point
    for i, target_dist in enumerate(target_distances):
        # Find the segment containing the target distance, i.e., the last segment
        # whose start distance is <= target_dist
        segment_idx = np.searchsorted(cumulative_lengths, target_dist, side='right') - 1
        segment_idx = min(max(segment_idx, 0), len(vertices) - 2)

        # Calculate interpolation factor within the segment (guard against
        # zero-length segments from duplicate vertices)
        segment_start_dist = cumulative_lengths[segment_idx]
        segment_length = segment_lengths[segment_idx]
        alpha = (
            (target_dist - segment_start_dist) / segment_length
            if segment_length > 0
            else 0.0
        )

        # Interpolate point
        start_point = vertices[segment_idx]
        end_point = vertices[segment_idx + 1]
        equidistant_points[i] = start_point + alpha * (end_point - start_point)

    return equidistant_points


def split_line_at_points(
    vertices: np.ndarray, split_points: np.ndarray
) -> list[np.ndarray]:
    """
    Split a line into segments between consecutive split points, dropping any
    part of the line before the first or after the last split point.

    Args:
        vertices: numpy array of shape (M, 2) containing the original line vertices
        split_points: numpy array of shape (N, 2) containing the equidistant points

    Returns:
        list of N-1 numpy arrays, each containing a segment of the original line between
        consecutive split points
    """
    # Calculate cumulative distances for original vertices
    segment_vectors = np.diff(vertices, axis=0)
    segment_lengths = np.linalg.norm(segment_vectors, axis=1)
    vertices_cum_lengths = np.concatenate(([0], np.cumsum(segment_lengths)))

    # Approximate each split point's position along the path by the cumulative
    # distance of its nearest original vertex (computed once, not per vertex)
    split_dists = np.array(
        [
            vertices_cum_lengths[np.argmin(np.linalg.norm(vertices - point, axis=1))]
            for point in split_points
        ]
    )

    result = []

    # Process only the internal segments (between points[0] and points[-1])
    for i in range(len(split_points) - 1):
        start_dist = split_dists[i]
        end_dist = split_dists[i + 1]

        # Keep the original vertices that lie strictly between the two split
        # points, using path distance to ensure correct ordering
        is_between = (vertices_cum_lengths > start_dist) & (
            vertices_cum_lengths < end_dist
        )

        result.append(
            np.vstack((split_points[i], vertices[is_between], split_points[i + 1]))
        )

    return result


def split_line_equidistant(vertices: np.ndarray, n_points: int) -> list[np.ndarray]:
    """
    Split a line into n_points-1 segments at n_points equidistant points.

    Args:
        vertices: numpy array of shape (M, 2) containing the line vertices
        n_points: number of points to split at

    Returns:
        list of n_points-1 numpy arrays, each containing a segment between
        consecutive equidistant points
    """
    equidistant_points = find_equidistant_vertices(vertices, n_points)
    return split_line_at_points(vertices, equidistant_points)


def points_in_polygon(points: np.ndarray, polygon: np.ndarray) -> np.ndarray:
    """
    Test if points are inside a polygon.

    Args:
        points: numpy array of shape (N, 2) containing test points
        polygon: numpy array of shape (M, 2) containing polygon vertices

    Returns:
        numpy array of shape (N,) containing boolean values
    """
    path = Path(polygon)
    return path.contains_points(points)

### Widget Composition

Finally, we're going to instantiate the scatter plot and all the other widgets and link them using their traits. The output is the UI you've been waiting for :)
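One step in the composition worth isolating is how `show_predicates` arranges DimBridge's output: predicates come back per selection, but the sidebar is organized per protein, so they get regrouped by attribute. Here is a standalone sketch of that regrouping, assuming (as the code below does) that each predicate is a dict with `attribute` and `interval` keys. `group_predicates_by_attribute` is a hypothetical helper, and the intervals in the usage example are made up.

```python
from collections import defaultdict


def group_predicates_by_attribute(
    predicates_by_selection: list[list[dict]], selection_names: list[str]
) -> dict[str, list[tuple[str, list[float]]]]:
    """Regroup per-selection predicates into {attribute: [(selection, interval), ...]}."""
    grouped = defaultdict(list)
    for name, predicates in zip(selection_names, predicates_by_selection):
        for predicate in predicates:
            grouped[predicate['attribute']].append((name, predicate['interval']))
    # Dicts keep insertion order, so attributes appear in the order first seen
    return dict(grouped)


grouped = group_predicates_by_attribute(
    [
        [
            {'attribute': 'CD3', 'interval': [0.1, 0.5]},
            {'attribute': 'CD4', 'interval': [0.2, 0.9]},
        ],
        [{'attribute': 'CD3', 'interval': [0.6, 1.0]}],
    ],
    ['Selection 1', 'Selection 2'],
)
print(grouped)
```

Each key then becomes one background histogram in the sidebar, followed by one foreground histogram per `(selection, interval)` pair.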

In [None]:
import numpy as np

from collections import OrderedDict
from dataclasses import dataclass, field
from dimbridge.predicate_engine import compute_predicate_sequence
from IPython.display import HTML, display
from ipywidgets import Checkbox, Dropdown, GridBox, HBox, Layout, IntText, Text, VBox
from itertools import cycle
from jscatter import Scatter, glasbey_light, link, okabe_ito, Line
from jscatter.widgets import Button
from numpy import histogram, isnan
from matplotlib.colors import to_hex
from scipy.spatial import ConvexHull

all_colors = okabe_ito.copy()
available_colors = [color for color in all_colors]

color_by = 'cell_population'

non_protein_cols = ['x', 'y', 'cell_population', 'cell_type']
protein_cols = [col for col in df.columns if col not in non_protein_cols]

color_map = dict(zip(df[color_by].unique(), cycle(glasbey_light[1:])))
color_map['Non-robust'] = (0.2, 0.2, 0.2, 1.0)

continuous_color_maps = [
    ['#00dadb', '#da00db'],
    ['#00dadb', '#a994dc', '#da00db'],
    ['#00dadb', '#8faddc', '#bd77dc', '#da00db'],
    ['#00dadb', '#7eb9dc', '#a994dc', '#c567dc', '#da00db'],
    ['#00dadb', '#72c0db', '#9aa3dc', '#b583dc', '#ca5cdb', '#da00db'],
    ['#00dadb', '#69c4db', '#8faddc', '#a994dc', '#bd77dc', '#cd54db', '#da00db'],
    [
        '#00dadb',
        '#62c7db',
        '#86b4dc',
        '#9e9fdc',
        '#b288dc',
        '#c16edc',
        '#cf4ddb',
        '#da00db',
    ],
    [
        '#00dadb',
        '#5ccadb',
        '#7eb9dc',
        '#96a7dc',
        '#a994dc',
        '#b87fdc',
        '#c567dc',
        '#d048db',
        '#da00db',
    ],
    [
        '#00dadb',
        '#57ccdb',
        '#78bddc',
        '#8faddc',
        '#a19ddc',
        '#b08bdc',
        '#bd77dc',
        '#c861db',
        '#d144db',
        '#da00db',
    ],
]

scatter = Scatter(
    data=df,
    x='x',
    y='y',
    background_color='#111111',
    axes=False,
    height=720,
    color_by=color_by,
    color_map=color_map,
    tooltip=True,
    tooltip_properties=[
        'color',
        'cell_type',
        'CD3',
        'CD4',
        'CD8',
        'CD19',
        'CD27',
        'CD45RA',
    ],
    tooltip_histograms_size='large',
)

# Remove non-robust cell populations from the histogram as it's uninteresting
color_histogram = scatter.widget.color_histogram.copy()
color_histogram[scatter._color_categories['Non-robust']] = 0
scatter.widget.color_histogram = color_histogram


@dataclass
class Selection:
    """Class for keeping track of a selection."""

    index: int
    name: str
    points: np.ndarray
    color: str
    lasso: Line
    hull: Line


@dataclass
class Selections:
    """Class for keeping track of selections."""

    selections: list[Selection] = field(default_factory=list)

    def all_points(self) -> np.ndarray:
        return np.unique(
            np.concatenate(
                list(map(lambda selection: selection.points, self.selections))
            )
        )

    def all_hulls(self) -> list[Line]:
        return [s.hull for s in self.selections]


@dataclass
class Lasso:
    """Class for keeping track of the lasso polygon."""

    polygon: Line | None = None


lasso = Lasso()
selections = Selections()


def update_annotations():
    lasso_polygon = [] if lasso.polygon is None else [lasso.polygon]
    scatter.annotations(selections.all_hulls() + lasso_polygon)


def lasso_selection_polygon_change_handler(change):
    if change['new'] is None:
        lasso.polygon = None
    else:
        points = change['new'].tolist()
        points.append(points[0])
        lasso.polygon = Line(points, line_color=scatter.widget.color_selected)
    update_annotations()


scatter.widget.observe(
    lasso_selection_polygon_change_handler, names=['lasso_selection_polygon']
)

selection_name = Text(value='', placeholder='Select some points…', disabled=True)
selection_name.layout.width = '100%'

selection_add = Button(
    description='',
    tooltip='Save Selection',
    disabled=True,
    icon='plus',
    width=36,
    rounded=['top-right', 'bottom-right'],
)

selection_subdivide = Checkbox(value=False, description='Subdivide', indent=False)

selection_num_subdivisions = IntText(
    value=5,
    min=2,
    max=10,
    step=1,
    description='Parts',
)

selection_subdivide_wrapper = HBox([selection_subdivide, selection_num_subdivisions])

selections_elements = VBox(layout=Layout(grid_gap='2px'))

selections_predicates_css = """
<style>
.jupyter-scatter-dimbridge-selections-predicates {
    position: absolute !important;
}

.jupyter-scatter-dimbridge-selections-predicates-wrapper {
    position: relative;
}
</style>
"""

display(HTML(selections_predicates_css))

selections_predicates = VBox(
    layout=Layout(
        top='4px',
        left='0px',
        right='0px',
        bottom='4px',
        grid_gap='4px',
    )
)
selections_predicates.add_class('jupyter-scatter-dimbridge-selections-predicates')

selections_predicates_wrapper = VBox(
    [selections_predicates],
    layout=Layout(
        height='100%',
    ),
)
selections_predicates_wrapper.add_class(
    'jupyter-scatter-dimbridge-selections-predicates-wrapper'
)

compute_predicates = Button(
    description='Compute Predicates',
    style='primary',
    disabled=True,
    full_width=True,
)

compute_predicates_between_selections = Checkbox(
    value=False, description='Compare Between Selections', indent=False
)

compute_predicates_wrapper = VBox([compute_predicates])


def add_selection_element(selection: Selection):
    hex_color = to_hex(selection.color)

    selection_name = Label(
        name=selection.name,
        style={'background': hex_color},
    )

    selection_remove = Button(
        description='',
        tooltip='Remove Selection',
        icon='trash',
        width=36,
        background=hex_color,
        rounded=['top-right', 'bottom-right'],
    )

    element = GridBox(
        [
            selection_name,
            selection_remove,
        ],
        layout=Layout(grid_template_columns='1fr 40px'),
    )

    def focus_handler(change):
        if change['new']:
            scatter.zoom(to=selection.points, animation=500, padding=2)
        else:
            scatter.zoom(to=None, animation=500, padding=0)

    selection_name.observe(focus_handler, names=['focus'])

    def remove_handler(change):
        selections_elements.children = [
            e for e in selections_elements.children if e != element
        ]
        # Compare by identity: the dataclass __eq__ would compare numpy arrays
        selections.selections = [
            s for s in selections.selections if s is not selection
        ]
        update_annotations()
        compute_predicates.disabled = len(selections.selections) == 0

    selection_remove.on_click(remove_handler)

    selections_elements.children = selections_elements.children + (element,)


def add_subdivided_selections():
    lasso_polygon = scatter.widget.lasso_selection_polygon

    # The brush polygon is assumed to trace one side of the stroke followed by the
    # other side in reverse, so split it in half and flip the second half to make
    # both sides run in the same direction
    lasso_mid = lasso_polygon.shape[0] // 2
    lasso_part_one = lasso_polygon[:lasso_mid, :]
    lasso_part_two = lasso_polygon[lasso_mid:, :][::-1]

    # IntText does not enforce `min`/`max`, so clamp the number of parts to the
    # 2 to 10 colors that `continuous_color_maps` provides
    num_subdivisions = min(
        max(selection_num_subdivisions.value, 2), len(continuous_color_maps) + 1
    )
    n_split_points = num_subdivisions + 1

    sub_lassos_part_one = split_line_equidistant(lasso_part_one, n_split_points)
    sub_lassos_part_two = split_line_equidistant(lasso_part_two, n_split_points)

    base_name = selection_name.value
    if len(base_name) == 0:
        base_name = f'Selection {len(selections.selections) + 1}'

    # continuous_color_maps[k] holds k + 2 colors, so this picks exactly one
    # color per part
    color_map = continuous_color_maps[num_subdivisions - 2]

    for i, part_one in enumerate(sub_lassos_part_one):
        polygon = np.vstack((part_one, sub_lassos_part_two[i][::-1]))
        idxs = np.where(points_in_polygon(df[['x', 'y']].values, polygon))[0]
        if len(idxs) < 3:
            # A convex hull needs at least three points, so skip near-empty parts
            continue
        points = df.iloc[idxs][['x', 'y']].values
        hull = ConvexHull(points)
        hull_points = np.vstack((points[hull.vertices], points[hull.vertices[0]]))
        color = color_map[i]
        name = f'{base_name}.{i + 1}'

        lasso_polygon = polygon.tolist()
        lasso_polygon.append(lasso_polygon[0])

        selection = Selection(
            index=len(selections.selections) + 1,
            name=name,
            points=idxs,
            color=color,
            lasso=Line(lasso_polygon),
            hull=Line(hull_points, line_color=color, line_width=2),
        )
        selections.selections.append(selection)
        add_selection_element(selection)


def add_selection():
    idxs = scatter.selection()
    points = df.iloc[idxs][['x', 'y']].values
    hull = ConvexHull(points)
    hull_points = np.vstack((points[hull.vertices], points[hull.vertices[0]]))
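    # Fall back to reusing palette colors once all Okabe-Ito colors are taken
    color = (
        available_colors.pop(0)
        if available_colors
        else all_colors[len(selections.selections) % len(all_colors)]
    )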

    name = selection_name.value
    if len(name) == 0:
        name = f'Selection {len(selections.selections) + 1}'

    lasso_polygon = scatter.widget.lasso_selection_polygon.tolist()
    lasso_polygon.append(lasso_polygon[0])

    selection = Selection(
        index=len(selections.selections) + 1,
        name=name,
        points=idxs,
        color=color,
        lasso=Line(lasso_polygon),
        hull=Line(hull_points, line_color=color, line_width=2),
    )
    selections.selections.append(selection)
    add_selection_element(selection)


def selection_add_handler(event):
    lasso.polygon = None

    if scatter.widget.lasso_type == 'brush' and selection_subdivide.value:
        add_subdivided_selections()
    else:
        add_selection()

    compute_predicates.disabled = False

    scatter.selection([])
    update_annotations()

    if len(selections.selections) > 1:
        compute_predicates_wrapper.children = (
            compute_predicates_between_selections,
            compute_predicates,
        )
    else:
        compute_predicates_wrapper.children = (compute_predicates,)


selection_add.on_click(selection_add_handler)


def selection_handler(change):
    if len(change['new']) > 0:
        selection_add.disabled = False
        selection_name.disabled = False
        selection_name.placeholder = 'Name selection…'
        new_index = 1
        if len(selections.selections) > 0:
            new_index = selections.selections[-1].index + 1
        selection_name.value = f'Selection {new_index}'
    else:
        selection_add.disabled = True
        selection_name.disabled = True
        selection_name.placeholder = 'Select some points…'
        selection_name.value = ''


scatter.widget.observe(selection_handler, names=['selection'])


def clear_predicates(event):
    compute_predicates.style = 'primary'
    compute_predicates.description = 'Compute Predicates'
    compute_predicates.on_click(compute_predicates_handler)

    selections_predicates.children = ()

    if len(selections.selections) > 1:
        compute_predicates_wrapper.children = (
            compute_predicates_between_selections,
            compute_predicates,
        )
    else:
        compute_predicates_wrapper.children = (compute_predicates,)


def show_predicates(predicates_by_selection, bins=30):
    compute_predicates.style = ''
    compute_predicates.description = 'Clear Predicates'
    compute_predicates.on_click(clear_predicates)

    # Group the predicates by attribute so that each protein gets one background
    # histogram followed by one foreground histogram per selection
    d = OrderedDict()

    for i, predicates in enumerate(predicates_by_selection):
        for predicate in predicates:
            d.setdefault(predicate['attribute'], []).append(
                {
                    'selection': selections.selections[i],
                    'predicate': predicate,
                }
            )

    # `i` indexes the attribute and is used below to insert a divider between
    # consecutive attributes
    for i, (column, predicate_by_selections) in enumerate(d.items()):
        data = df[column]
        bg_values = data[~isnan(data)]
        bg_hist, bg_bins = histogram(bg_values, bins=bins)

        # Background distribution
        bg_hist_plot = Histogram(
            histogram=list(map(float, bg_hist)),
            color='#333',
            height=24,
            title=column,
            title_color='white',
        )

        if i > 0:
            divider = Div(
                style={
                    'height': '1px',
                    'margin': '4px 0 2px 0',
                }
            )
            selections_predicates.children = selections_predicates.children + (
                divider,
                bg_hist_plot,
            )
        else:
            selections_predicates.children = selections_predicates.children + (
                bg_hist_plot,
            )

        for predicate_by_selection in predicate_by_selections:
            predicate = predicate_by_selection['predicate']
            selection = predicate_by_selection['selection']
            points = selection.points

            # Bin the selection's values with the background's bin edges so the
            # foreground and background histograms line up
            values = data.iloc[points]
            values = values[~isnan(values)]
            hist, _ = histogram(values, bins=bg_bins)

            # Foreground
            hist_plot = Histogram(
                histogram=list(map(float, hist)),
                color=to_hex(selection.color),
                background_color='#333',
                height=24,
                interval=predicate['interval'],
                interval_height=3,
                interval_padding=3,
            )

            selections_predicates.children = selections_predicates.children + (
                hist_plot,
            )


def compute_predicates_handler(event):
    if len(selections.selections) == 0:
        return

    compute_predicates.disabled = True
    compute_predicates.description = 'Computing Predicates…'

    num_selections = len(selections.selections)
    num_points = len(df)

    selection_masks = np.zeros((num_selections, num_points), dtype=bool)

    for i, selection in enumerate(selections.selections):
        selection_masks[i, selection.points] = True

    df_subset = df[protein_cols]

    if num_selections > 1 and compute_predicates_between_selections.value:
        df_subset = df_subset.iloc[selections.all_points()]
        selection_masks = selection_masks[:, selections.all_points()]

    predicates, qualities, parameters = compute_predicate_sequence(
        x0=df_subset.values,
        selected=selection_masks,
        attribute_names=protein_cols,
        n_iter=500,
    )

    show_predicates(predicates)

    compute_predicates.disabled = False


compute_predicates.on_click(compute_predicates_handler)

add = GridBox(
    [
        selection_name,
        selection_add,
    ],
    layout=Layout(grid_template_columns='1fr 40px'),
)

complete_add = VBox([add], layout=Layout(grid_gap='4px'))


def lasso_type_change_handler(change):
    if change['new'] == 'brush':
        complete_add.children = (add, selection_subdivide_wrapper)
    else:
        complete_add.children = (add,)


scatter.widget.observe(lasso_type_change_handler, names=['lasso_type'])

color_by = Dropdown(
    options=[('Cell Population', 'cell_population')] + [(p, p) for p in protein_cols],
    value=color_by,
    description='Color By:',
)


def color_by_change_handler(change):
    cmap = color_map if change['new'] == 'cell_population' else 'magma'
    scatter.color(by=change['new'], map=cmap)


color_by.observe(color_by_change_handler, names=['value'])

plot_wrapper = VBox([scatter.show(), color_by])

sidebar = GridBox(
    [
        complete_add,
        selections_elements,
        selections_predicates_wrapper,
        compute_predicates_wrapper,
    ],
    layout=Layout(
        grid_template_rows='min-content max-content 1fr min-content',
    ),
)

GridBox(
    [
        plot_wrapper,
        sidebar,
    ],
    layout=Layout(grid_template_columns='1fr minmax(10rem, 20rem)'),
)

---

To test DimBridge:
1. Select some points using Jupyter-Scatter's lasso
2. Add the selection using the `plus` button in the top-right corner
3. Click on `Compute Predicates`

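Once DimBridge is done, you'll see a set of histograms representing the dimensions that best describe the cluster in the data space, i.e., the key protein expressions that differentiate that cluster from the rest of the dataset. For each protein, the dark histogram labeled with the protein's name shows the distribution across all cells, and each colored histogram below it shows the distribution within one selection, with the bar underneath marking the predicate's value range. You can quickly verify the expression of an identified protein by changing the color encoding via the drop-down menu in the bottom-left corner.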
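If you want to double-check a predicate numerically rather than visually, you can apply its intervals as a conjunction and compare the result against your selection. The sketch below is hypothetical: it assumes, as the `Histogram` widget's drawing code implies, that each interval is expressed as a fraction of the column's value range, which is why it min-max normalizes the data first. If DimBridge returns intervals in raw units instead, drop the normalization. `predicate_mask` and `precision_recall` are illustrative names, and the data is made up.

```python
import numpy as np


def predicate_mask(
    x: np.ndarray, attribute_names: list[str], predicates: list[dict]
) -> np.ndarray:
    """Rows of x satisfying every predicate interval (a conjunction)."""
    # Assumption: intervals are fractions of each column's [min, max] range
    x_min = np.nanmin(x, axis=0)
    x_max = np.nanmax(x, axis=0)
    x_range = np.where(x_max > x_min, x_max - x_min, 1.0)
    x_norm = (x - x_min) / x_range

    mask = np.ones(len(x), dtype=bool)
    for predicate in predicates:
        j = attribute_names.index(predicate['attribute'])
        low, high = predicate['interval']
        # NaN comparisons evaluate to False, so rows with missing values drop out
        mask &= (x_norm[:, j] >= low) & (x_norm[:, j] <= high)
    return mask


def precision_recall(mask: np.ndarray, selected: np.ndarray) -> tuple[float, float]:
    """Share of predicate hits that are selected, and share of the selection covered."""
    true_positives = int(np.sum(mask & selected))
    precision = true_positives / max(int(mask.sum()), 1)
    recall = true_positives / max(int(selected.sum()), 1)
    return precision, recall


x = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
mask = predicate_mask(x, ['CD3', 'CD4'], [{'attribute': 'CD3', 'interval': [0.4, 1.0]}])
print(mask, precision_recall(mask, np.array([False, True, False])))
```

High precision means the predicate rarely fires outside the selection; high recall means it captures most of the selected cells.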

There are three modes of operation:
1. Explain a single cluster
2. Compare multiple clusters
3. Explore a sequence of clusters

For mode 2, add two or more selections and then select `Compare Between Selections`.
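Under the hood, mode 2 changes which points DimBridge contrasts a selection against. Below is a minimal sketch of the mask construction in `compute_predicates_handler`, pulled out into a hypothetical standalone function `build_selection_masks`. Each selection becomes one row of a boolean mask. When comparing between selections, the columns are restricted to the union of all selected points, so each selection is explained relative to the other selections rather than to the entire dataset.

```python
import numpy as np


def build_selection_masks(
    selection_points: list[np.ndarray], num_points: int, between_selections: bool
) -> tuple[np.ndarray, np.ndarray]:
    """Return (masks, row_idxs): one boolean row per selection and the data rows kept."""
    masks = np.zeros((len(selection_points), num_points), dtype=bool)
    for i, points in enumerate(selection_points):
        masks[i, points] = True

    # By default, every data point takes part in the comparison
    row_idxs = np.arange(num_points)

    if between_selections and len(selection_points) > 1:
        # Keep only points that belong to at least one selection
        row_idxs = np.unique(np.concatenate(selection_points))
        masks = masks[:, row_idxs]

    return masks, row_idxs


masks, rows = build_selection_masks([np.array([0, 1]), np.array([3])], 5, True)
print(rows)
print(masks)
```

The returned `row_idxs` would select the matching rows of the data matrix, mirroring `df_subset.iloc[selections.all_points()]` in the handler.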

For mode 3, switch Jupyter-Scatter's lasso type to `brush` using the third button from the top-left corner. Then select some points and activate `Subdivide` prior to adding the selection.