# Text Labels

In this notebook, we demonstrate how to annotate and display points with text labels in Jupyter Scatter using three examples:
1. Geospatial city dataset
2. Single-cell embedding
3. arXiv ML paper embedding

> **NOTE:**
>
> In order to run this notebook you need to have Jupyter Scatter `v0.22.0` or higher installed.

In [None]:
# If you run this notebook in Google Colab, you need to manually install the following packages.
# !pip install --quiet jupyter-scatter

## Geospatial City Dataset

For the first demo, we're going to load cities across the world from the [GeoNames dataset](https://www.geonames.org/about.html).

In [None]:
!curl -L -C - --create-dirs -o data/cities.pq https://storage.googleapis.com/flekschas/jupyter-scatter-tutorial/cities.pq

In [None]:
import pandas as pd
df_cities = pd.read_parquet('data/cities.pq')
df_cities.head(3)

Next, we label points by continent.

In [None]:
from jscatter import Scatter

# Basic scatter plot configuration that we're going to reuse a few times
scatter_cities_base_config = dict(
    data=df_cities,
    x='Mercator X',
    y='Mercator Y',
    color_by='Continent',
    size=2,
    axes=False,
    legend=True,
    height=640,
)

scatter_cities = Scatter(**scatter_cities_base_config)

# The column by which we want to label points
scatter_cities.label(by='Continent')

scatter_cities.show()

We can customize the appearance of the text labels in various ways if we like:

In [None]:
scatter_cities.label(
    # For better visibility we're using a bold font
    font='arial bold',
    # Place labels at the center of mass of the largest
    # cluster. This is useful when the labeled points
    # correspond to multiple, disconnected clusters
    positioning='largest_cluster',
    # Use the inverse hyperbolic sine scale function so that
    # labels get enlarged sublinearly as you zoom in
    scale_function='asinh'
)
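
To get a feel for why `asinh` enlarges labels sublinearly, the sketch below compares a linear scale with `numpy.arcsinh` over a few zoom factors. It only illustrates the shape of the curve; how Jupyter Scatter maps its internal zoom level to a label scale is an assumption here and is not specified in this notebook.

```python
import numpy as np

# Hypothetical zoom factors (1 = initial view, 16 = zoomed in 16x)
zoom = np.array([1.0, 2.0, 4.0, 8.0, 16.0])

# A linear scale grows labels proportionally to the zoom factor
linear_scale = zoom

# asinh grows roughly like log(2x) for large x, i.e., sublinearly
asinh_scale = np.arcsinh(zoom)

for z, lin, ash in zip(zoom, linear_scale, asinh_scale):
    print(f'zoom {z:>4.0f}x: linear {lin:>5.1f}, asinh {ash:>4.2f}')
```

Doubling the zoom doubles the linear scale but adds only a roughly constant amount to the `asinh` scale, which keeps labels from growing huge when you zoom in far.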

While it's nice to see labels for the different continents, wouldn't it be even cooler to also see labels for the actual cities? Since the dataset contains more than 100k cities, computing the static label placement can take a moment. The good news is that we can do this computation upfront by using `LabelPlacement` directly.

A couple of notes:
- To label points in multiple ways, we can simply pass multiple columns to `by`.
- Since we want to treat each city name as a unique (i.e., per-point) label, we need to append an exclamation mark to the column name. This is necessary because city names are not unique. For instance, there's a city called Berlin in Germany (the capital) and another one called Berlin in Massachusetts, United States.
- To enforce that continent labels are drawn before city labels, we set `hierarchical` to `True`, which tells the labeler that the different label types form a strict hierarchy. Additionally, we change the importance aggregation to `"sum"` so that a continent's importance is the total population of its cities, which is at least as large as the population of any single city on it.
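
The two notes about `!` and `"sum"` can be illustrated with plain pandas on a tiny table. The rows and population values below are rough illustrative placeholders, not GeoNames data, and the code mirrors the idea only; it is not how `LabelPlacement` is implemented. Grouping by `Name` merges both Berlins into one label, whereas one label per row keeps them apart, and summing populations per continent yields an importance at least as large as that of any city on that continent.

```python
import pandas as pd

# Tiny illustrative table (values are rough placeholders, not GeoNames data)
df = pd.DataFrame({
    'Name': ['Berlin', 'Berlin', 'Paris', 'Boston'],
    'Continent': ['Europe', 'North America', 'Europe', 'North America'],
    'Population': [3_600_000, 3_000, 2_100_000, 650_000],
})

# Grouping by name collapses both Berlins into a single label ...
labels_by_name = df.groupby('Name').size()

# ... whereas a per-point label (the idea behind the `!` suffix) keeps one label per row
point_labels = list(df.itertuples(index=False))

# Summing populations per continent gives each continent an importance that
# is at least as large as the importance of any city on that continent
continent_importance = df.groupby('Continent')['Population'].sum()
max_city_importance = df.groupby('Continent')['Population'].max()

# For comparison: a mean aggregation can rank a continent below its largest city
continent_importance_mean = df.groupby('Continent')['Population'].mean()

print(labels_by_name)
print(len(point_labels))
print(continent_importance)
print(continent_importance_mean)
```

With the mean, Europe would get an importance of 2.85M, which is less than Berlin's 3.6M, so a city could outrank its own continent.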

In [None]:
from jscatter import LabelPlacement, arial

label_placement = LabelPlacement(
    data=scatter_cities._data,
    x=scatter_cities._x,
    y=scatter_cities._y,
    by=['Continent', 'Name!'],
    hierarchical=True,
    importance='Population',
    importance_aggregation='sum',
    font=[arial.bold, arial.regular],
    size=[24, 16],
    tile_size=scatter_cities._height,
)
    
label_placement.compute(show_progress=True)

Now all we have to do is use the precomputed labels as follows. We also change the label alignment from `"center"` (default) to `"top"` to draw labels above their position. This is useful for point labels, as the label would otherwise cover the point. Additionally, we add a negative y-offset of 2px.

In [None]:
scatter_cities.label(using=label_placement, align='top', offset=(0, -2))

If you plan to use the labels often, you can store them as two parquet files for faster reuse.

In [None]:
label_placement.to_parquet('data/cities')

We can now load the precomputed labels and instantly create a new scatter plot with them.

In [None]:
from jscatter import Scatter, LabelPlacement

label_placement_2 = LabelPlacement.from_parquet('data/cities')
scatter_cities_2 = Scatter(**scatter_cities_base_config, label_using=label_placement_2)
scatter_cities_2.show()

## Single-Cell Embedding

For the next demo, we load a single-cell surface protein dataset from [Mair et al. (2022)](https://www.nature.com/articles/s41586-022-04718-w). The cells were clustered with [Greene et al.'s (2021) FAUST method](https://www.cell.com/patterns/fulltext/S2666-3899(21)00234-8) to derive cell populations and cell types, and embedded with t-SNE for visualization.

In [None]:
!curl -L -C - --create-dirs -o data/mair-2022-tissue-138.pq https://storage.googleapis.com/flekschas/jupyter-scatter-tutorial/mair-2022-tissue-138.pq

In [None]:
import pandas as pd
df_single_cell = pd.read_parquet('./data/mair-2022-tissue-138.pq')
df_single_cell.head(2)

For this example, we label points by their (broader) cell type and their phenotype. To better associate labels with points, we color each label using the color of its points.

Also, sometimes you might not want to show all labels. For instance, in the following example some cells are associated with an unspecified cell type or a non-robust phenotype. To avoid showing labels for those classes, we exclude the two labels "Other" and "Non-robust".
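
Label exclusion refers to each label as `<label type>:<label value>`, the same key format the label color map in the following cell uses. Here is a minimal pandas sketch, using made-up rows rather than the real dataset, of turning two label columns into such keys and dropping the excluded ones. It illustrates the key format only and is not Jupyter Scatter's internal filtering code.

```python
import pandas as pd

# Made-up rows standing in for the single-cell data
df = pd.DataFrame({
    'cell_type': ['T cell', 'B cell', 'Other', 'T cell'],
    'cell_population': ['CD4 naive', 'Non-robust', 'Non-robust', 'CD8 memory'],
})

exclude = {'cell_type:Other', 'cell_population:Non-robust'}

# Build `type:label` keys for every distinct label of every label column
label_keys = [
    f'{col}:{value}'
    for col in ['cell_type', 'cell_population']
    for value in df[col].unique()
]

# Keep only the labels that are not excluded
visible_labels = [key for key in label_keys if key not in exclude]
print(visible_labels)
```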

In [None]:
from itertools import cycle
from jscatter import Scatter, glasbey_light, okabe_ito, brighten, saturate

# Color map specifications
non_protein_cols = ['x', 'y', 'cell_population', 'cell_type']
protein_cols = [col for col in df_single_cell.columns if col not in non_protein_cols]

cell_type_color_map = dict(zip(df_single_cell['cell_type'].unique(), cycle(okabe_ito[:7])))
cell_type_color_map['Other'] = (0.2, 0.2, 0.2, 1.0)

cell_population_color_map = dict(zip(df_single_cell['cell_population'].unique(), cycle(glasbey_light)))
cell_population_color_map['Non-robust'] = (0.2, 0.2, 0.2, 1.0)

label_color_map = {
    **{f'cell_type:{cell_type}': brighten(saturate(color, 3 / 2), 2) for cell_type, color in cell_type_color_map.items()},
    **{f'cell_population:{cell_pop}': brighten(saturate(color, 3 / 2), 2) for cell_pop, color in cell_population_color_map.items()}
}

# Scatter specifications
scatter_single_cell = Scatter(
    data=df_single_cell,
    x='x',
    y='y',
    background_color='#111111',
    axes=False,
    height=720,
    color_by='cell_type',
    color_map=cell_type_color_map,
    tooltip=True,
    tooltip_properties=['cell_type', 'CD3', 'CD4', 'CD8', 'CD19', 'CD27', 'CD45RA'],
    label_by=['cell_type', 'cell_population'],
    label_hierarchical=True,
    label_size={'cell_type': 36, 'cell_population': 12},
    label_color=label_color_map,
    label_positioning='largest_cluster',
    # We restrict the zoom level of the two label types such that broad
    # cell types are shown at the beginning and fade out at level `2.25`,
    # and specific cell phenotypes are allowed to appear from level `1.5` up.
    # Note that the zoom ranges can overlap as in this example.
    label_zoom_range={'cell_type': (-float('inf'), 2.25), 'cell_population': (1.5, float('inf'))},
    # To avoid showing labels for other cell types and non-robust cell
    # phenotypes, we `exclude` those. Since we have two types of labels, we
    # need to reference each label by its type: `cell_type:Other` and
    # `cell_population:Non-robust`.
    label_exclude=['cell_type:Other', 'cell_population:Non-robust'],
    # Since some cell phenotype labels are super long, we're going to break
    # labels into multiple lines to optimize for a label aspect ratio of `5`.
    label_target_aspect_ratio=5,
)

# At zoom level `1.5` we're going to switch the color map from broad cell types
# to specific cell phenotypes to better identify phenotypes.
def zoom_level_change_handler(change):
    if change['new'] >= 1.5:
        scatter_single_cell.color(by='cell_population', map=cell_population_color_map)
    else:
        scatter_single_cell.color(by='cell_type', map=cell_type_color_map)

scatter_single_cell.widget.observe(zoom_level_change_handler, names='zoom_level')

scatter_single_cell.show()

Zoom in to see how more and more labels appear and the color automatically changes.
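
The `label_zoom_range` setting defines, per label type, a zoom interval in which labels of that type may appear. The sketch below shows which label types are eligible at a few zoom levels. It assumes inclusive bounds, which the notebook does not specify, and it ignores collisions, so an eligible label may still be hidden if it overlaps a more important one.

```python
# Zoom ranges copied from the scatter configuration above
label_zoom_range = {
    'cell_type': (-float('inf'), 2.25),
    'cell_population': (1.5, float('inf')),
}

def eligible_label_types(zoom_level, zoom_ranges):
    """Return the label types whose zoom range contains `zoom_level`.

    Inclusive bounds are an illustrative assumption.
    """
    return [
        label_type
        for label_type, (min_zoom, max_zoom) in zoom_ranges.items()
        if min_zoom <= zoom_level <= max_zoom
    ]

for zoom_level in [1.0, 2.0, 3.0]:
    print(zoom_level, eligible_label_types(zoom_level, label_zoom_range))
```

Between `1.5` and `2.25` both label types are eligible, which is the overlap mentioned in the comment of the scatter configuration.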

## ML arXiv Papers

The following plot shows an embedding of machine learning papers from arXiv and is an [example from Leland McInnes's excellent DataMapPlot library](https://datamapplot.readthedocs.io/en/latest/interactive_intro.html). Huge kudos to Leland for sharing the data 🙏.

In [None]:
!curl -L -C - --create-dirs -o data/arxiv-ml.pq https://storage.googleapis.com/flekschas/jupyter-scatter-tutorial/arxiv-ml.pq

In [None]:
import pandas as pd
df_arxiv_ml = pd.read_parquet('data/arxiv-ml.pq')
df_arxiv_ml.head(2)

### 2D Color Map

Let us first create a color map for all categories using the fantastic Schumann color map from https://pypi.org/project/pycolormap-2d/. (In case you're wondering, we had to integrate their library because, sadly, it's not yet compatible with NumPy v2.)

In [None]:
from jscatter import ColorMap2DSchumann, brighten, saturate
from matplotlib.colors import to_hex

cmap = ColorMap2DSchumann(
    range_x=(df_arxiv_ml.x.min(), df_arxiv_ml.x.max()),
    range_y=(df_arxiv_ml.y.min(), df_arxiv_ml.y.max()),
)

df_arxiv_ml_category_cmap = {}
df_arxiv_ml_label_cmap = {}

for i in range(1, 6):
    label_type = f'category_{i}'
    labels = df_arxiv_ml[label_type].dropna().unique()
    for label in labels:
        mask = df_arxiv_ml[label_type] == label
        
        cx = df_arxiv_ml[mask].x.median()
        cy = df_arxiv_ml[mask].y.median()

        color = to_hex(cmap(cx, cy) / 255)
        
        df_arxiv_ml_category_cmap[label] = color
        df_arxiv_ml_label_cmap[f'{label_type}:{label}'] = brighten(saturate(color, 3), 3)

# To color "noise" points (i.e., `NaN` values) dark gray, we replace `NaN` with the category `"NA"`
if 'NA' not in df_arxiv_ml['category'].cat.categories.tolist():
    df_arxiv_ml['category'] = df_arxiv_ml['category'].cat.add_categories('NA')
df_arxiv_ml['category'] = df_arxiv_ml['category'].fillna('NA')

df_arxiv_ml_category_cmap['NA'] = '#333333'
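
The loop above assigns every category the color found at its median position in a 2D color map. To illustrate that lookup principle without `ColorMap2DSchumann`, here is a minimal sketch that swaps in a simple bilinear blend of four hypothetical corner colors. It is not the Schumann color map; it only shows how an (x, y) position within a data range can be mapped to an RGB color.

```python
import numpy as np

# Hypothetical corner colors (RGB in [0, 1]): bottom-left, bottom-right, top-left, top-right
CORNERS = {
    'bl': np.array([0.0, 0.0, 1.0]),
    'br': np.array([1.0, 0.0, 0.0]),
    'tl': np.array([0.0, 1.0, 0.0]),
    'tr': np.array([1.0, 1.0, 0.0]),
}

def bilinear_color(x, y, range_x, range_y, corners=CORNERS):
    """Map a 2D position to an RGB color by bilinearly blending corner colors."""
    # Normalize the position to [0, 1] within the data range and clamp it
    u = np.clip((x - range_x[0]) / (range_x[1] - range_x[0]), 0.0, 1.0)
    v = np.clip((y - range_y[0]) / (range_y[1] - range_y[0]), 0.0, 1.0)
    # Blend horizontally along the bottom and top edges, then vertically
    bottom = (1 - u) * corners['bl'] + u * corners['br']
    top = (1 - u) * corners['tl'] + u * corners['tr']
    return (1 - v) * bottom + v * top

# A category whose median position sits at the center gets the average of all corners
print(bilinear_color(0.0, 0.0, range_x=(-10, 10), range_y=(-5, 5)))
```

Passing the result to `matplotlib.colors.to_hex` produces a hex string, just like the `to_hex(...)` call in the loop above.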

Next, we create a scatter plot as before. Since the labels are hierarchical, we set `hierarchical=True` to ensure that `category_1` labels are shown before `category_2` labels, and so on.
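
One way to picture a strict label hierarchy, assuming it means every label of a broader level is considered before any label of a finer level regardless of importance, is to sort candidates by level first and by importance second. The candidate labels and importance values below are hypothetical, and this is only a conceptual sketch, not Jupyter Scatter's placement algorithm.

```python
# Hypothetical candidate labels: (label type, label, importance)
candidates = [
    ('category_2', 'Graph neural networks', 900),
    ('category_1', 'Deep learning', 500),
    ('category_3', 'Message passing', 1200),
    ('category_1', 'Reinforcement learning', 300),
]

# Level order from broad to specific, as passed to `by`
levels = ['category_1', 'category_2', 'category_3', 'category_4', 'category_5']
level_rank = {label_type: rank for rank, label_type in enumerate(levels)}

def placement_order(candidates, hierarchical):
    """Return labels in the order a greedy placer would consider them."""
    def sort_key(candidate):
        label_type, _, importance = candidate
        if hierarchical:
            # Broader levels first; within a level, more important labels first
            return (level_rank[label_type], -importance)
        # Without a hierarchy, importance alone decides the order
        return (-importance,)

    return [label for _, label, _ in sorted(candidates, key=sort_key)]

print(placement_order(candidates, hierarchical=True))
print(placement_order(candidates, hierarchical=False))
```

Without the hierarchy, a highly important fine-grained label such as "Message passing" would be considered before the broad "Deep learning" label.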

In [None]:
from jscatter import Scatter

arxiv_ml_scatter = Scatter(
    data=df_arxiv_ml,
    x='x',
    y='y',
    background_color='#111111',
    axes=False,
    height=720,
)
arxiv_ml_scatter.color(by='category', map=df_arxiv_ml_category_cmap)
arxiv_ml_scatter.tooltip(
    enable=True,
    preview='title',
    preview_type='text',
    properties=[],
)
arxiv_ml_scatter.label(
    by=['category_1', 'category_2', 'category_3', 'category_4', 'category_5'],
    size=[24, 20, 16, 12, 10],
    color=df_arxiv_ml_label_cmap,
    scale_function='constant',
    hierarchical=True,
    target_aspect_ratio=2,
)

arxiv_ml_scatter.show()