# Multiple Annotation Embedding

This notebook shows how annotation embedding avoids batch effects in pairwise comparisons of healthy and tumor samples from [Mair et al., 2022, Nature](https://www.nature.com/articles/s41586-022-04718-w).

In [1]:
import jscatter
import math
import numpy as np
import pandas as pd
import re
import transformation
import colors

from glob import glob
from openTSNE.sklearn import TSNE

In [2]:
dataset_name_tissue = 'TISSUE_138'
dataset_name_tumor = 'TUMOR_007'


def find_dataset(dataset_name, data_dir='data/mair-2022'):
    """Return the path of the data file whose name starts with `dataset_name`.

    `glob` returns matches in arbitrary order, so the matches are sorted to
    make the choice deterministic when more than one file shares the prefix.

    Raises:
        FileNotFoundError: if no file in `data_dir` starts with `dataset_name`.
    """
    matches = sorted(glob(f'{data_dir}/{dataset_name}*'))
    if not matches:
        raise FileNotFoundError(
            f'No data file matching "{dataset_name}*" found in "{data_dir}"'
        )
    return matches[0]


dataset_tissue = find_dataset(dataset_name_tissue)
dataset_tumor = find_dataset(dataset_name_tumor)

# Sample identifiers are cut out of the file paths with fixed offsets.
# NOTE(review): the prefix 'data/mair-2022/' is 15 characters long, so starting
# at index 16 also drops the first character of the file name, and -25 assumes
# a fixed-length suffix -- confirm both offsets against the actual file names.
sample_tissue = dataset_tissue[16:-25]
sample_tumor = dataset_tumor[16:-25]

df_tissue = pd.read_parquet(dataset_tissue)
df_tumor = pd.read_parquet(dataset_tumor)

# Balance the two samples: randomly downsample one frame (fixed seed) to the
# row count of the other so both contribute the same number of rows.
if len(df_tissue) < len(df_tumor):
    df_tumor = df_tumor.sample(n=len(df_tissue), random_state=42)
else:
    df_tissue = df_tissue.sample(n=len(df_tumor), random_state=42)

**Joint Embedding:**

In [None]:
# Re-import the local `transformation` module so that edits made to it since
# the kernel started are picked up without a restart.
from importlib import reload
reload(transformation)

# Both embeddings use the same embedder class and the same seed; only the
# `embed_raw` flag and the output name differ between the two calls.
# NOTE(review): `save_as` presumably persists the result as
# 'data/<save_as>.pq' (the visualization cell reads those paths) -- confirm in
# `transformation.transform_embed`.
tsne_options = dict(embeddor=TSNE, embeddor_random_state=42)

# Embedding computed with the default settings (no `embed_raw`).
df_ann_embed_tsne = transformation.transform_embed(
    [df_tissue, df_tumor],
    **tsne_options,
    save_as=f'{dataset_name_tissue}_{dataset_name_tumor}_tsne_ann',
)

# Embedding computed with `embed_raw=True`, for comparison.
df_raw_embed_tsne = transformation.transform_embed(
    [df_tissue, df_tumor],
    **tsne_options,
    embed_raw=True,
    save_as=f'{dataset_name_tissue}_{dataset_name_tumor}_tsne_raw',
)

**Visualize Embeddings:**

In [10]:
# Load the previously embedded data from disk. These reads are active, so they
# replace whatever `df_ann_embed_tsne` / `df_raw_embed_tsne` currently hold;
# comment them out to keep results already computed in this session.
df_ann_embed_tsne = pd.read_parquet(f'data/{dataset_name_tissue}_{dataset_name_tumor}_tsne_ann.pq')
df_raw_embed_tsne = pd.read_parquet(f'data/{dataset_name_tissue}_{dataset_name_tumor}_tsne_raw.pq')

# One color per distinct `faustLabels` value found in the tissue sample.
cell_type_color_map = colors.get_cmap(len(df_tissue['faustLabels'].unique()), mode='dark')

# NOTE(review): the keys are the literal strings 'sample_tissue' and
# 'sample_tumor', not the values of the `sample_tissue` / `sample_tumor`
# variables -- confirm they match the categories in `sampleOfOrigin`.
sample_color_map = {
    'sample_tissue': '#0072B2',  # blue
    'sample_tumor': '#E69F00',  # orange
}

# Settings shared by every scatter plot, extended per coloring below.
base_view_config = {
    'x': 'x',
    'y': 'y',
    'opacity_unselected': 0.1,
    'background_color': 'black',
    'axes': False,
}
cell_type_view_config = {'color_by': 'cellType', 'color_map': cell_type_color_map, **base_view_config}
sample_view_config = {'color_by': 'sampleOfOrigin', 'color_map': sample_color_map, **base_view_config}
marker_view_config = {'color_by': 'ICOS_Windsorized', 'color_map': 'viridis', **base_view_config}

compose_config = {'sync_selection': True, 'sync_hover': True, 'row_height': 400, 'rows': 2}

# Three views per embedding: colored by cell type, by sample, and by marker.
plot_ann_embed_tsne_ct, plot_ann_embed_tsne_sp, plot_ann_embed_tsne_mx = (
    jscatter.Scatter(data=df_ann_embed_tsne, **view_config)
    for view_config in (cell_type_view_config, sample_view_config, marker_view_config)
)

plot_raw_embed_tsne_ct, plot_raw_embed_tsne_sp, plot_raw_embed_tsne_mx = (
    jscatter.Scatter(data=df_raw_embed_tsne, **view_config)
    for view_config in (cell_type_view_config, sample_view_config, marker_view_config)
)

# Lay out all six plots with the `compose_config` settings (synced selection
# and hover, 2 rows); the three `ann` plots are listed first, then the three
# `raw` plots.
jscatter.compose(
    [
        plot_ann_embed_tsne_ct,
        plot_ann_embed_tsne_sp,
        plot_ann_embed_tsne_mx,
        plot_raw_embed_tsne_ct,
        plot_raw_embed_tsne_sp,
        plot_raw_embed_tsne_mx,
    ],
    **compose_config
)

GridBox(children=(HBox(children=(VBox(children=(Button(button_style='primary', icon='arrows', layout=Layout(wi…

In [11]:
# Pick every row of the `ann` embedding frame whose `cellType` equals this
# label, then select those rows in the cell-type plot. The composed view was
# created with `sync_selection=True`.
selected_cells = df_ann_embed_tsne.query(
    'cellType == "CD4+CD8-CD3+CD45RA-CD27+CD19-CD103-CD28+CD69-PD1+HLADR+GranzymeB-CD25+ICOS+TCRgd-CD38+CD127-Tim3+"'
)
plot_ann_embed_tsne_ct.selection(selected_cells.index)

<jscatter.jscatter.Scatter at 0x1bc3aed00>