In [None]:
import os
from collections import defaultdict
from math import ceil

import anndata
import faiss
import numpy as np
import pandas as pd
import plotly.io as pio
import scanpy as sc
from matplotlib import rcParams

import scglue

In [None]:
# Output directory for every artefact written by this notebook
PATH = "s06_sankey"
os.makedirs(name=PATH, exist_ok=True)

# Publication plotting style with a square default figure
scglue.plot.set_publication_params()
rcParams.update({"figure.figsize": (8, 8)})

# Read data

In [None]:
# Open both GLUE-integrated datasets read-only in backed mode; this notebook
# only touches their annotations (obs/uns) and embeddings (obsm).
rna, atac = (
    anndata.read_h5ad(h5ad_path, backed="r")
    for h5ad_path in ("s04_glue_final/full/rna.h5ad", "s04_glue_final/full/atac.h5ad")
)

In [None]:
# Expose the existing ATAC annotation under the name "NNLS", the label source
# that the GLUE-transferred labels are compared against below.
nnls_labels = atac.obs["cell_type"]
atac.obs["NNLS"] = nnls_labels

# Transfer labels

In [None]:
def l2_normalize(latent):
    """Return ``latent`` with every row scaled to unit L2 norm.

    The result is a C-contiguous float32 array, which is the layout faiss
    indices expect. Rows whose norm is zero stay zero instead of becoming NaN.
    """
    latent = np.asarray(latent)
    norm = np.linalg.norm(latent, axis=1, keepdims=True)
    # Clamp the denominator: a zero row gives 0 / tiny = 0 rather than 0 / 0 = NaN
    norm = np.maximum(norm, np.finfo(norm.dtype).tiny)
    return np.ascontiguousarray(latent / norm, dtype=np.float32)


# With unit-norm rows, inner product equals cosine similarity
rna_latent = l2_normalize(rna.obsm["X_glue"])
atac_latent = l2_normalize(atac.obsm["X_glue"])

In [None]:
np.random.seed(0)

n_rna, n_dims = rna_latent.shape
# Inverted-file index over the RNA embedding with ~sqrt(n) Voronoi cells,
# scored by inner product (cosine similarity on the unit-norm vectors).
quantizer = faiss.IndexFlatIP(n_dims)
n_voronoi = max(1, int(round(np.sqrt(n_rna))))
index = faiss.IndexIVFFlat(quantizer, n_dims, n_voronoi, faiss.METRIC_INNER_PRODUCT)
# Train the coarse quantizer on ~50 random points per Voronoi cell. The sample
# is capped at n_rna because sampling without replacement cannot draw more
# points than exist (the uncapped form fails for n_rna < ~2500).
n_train = min(50 * n_voronoi, n_rna)
index.train(rna_latent[np.random.choice(n_rna, n_train, replace=False)])
index.add(rna_latent)
# NOTE: index.nprobe is left at the faiss default (1), so a query only scans
# its nearest Voronoi cell and can return fewer than k hits (labelled -1).

# Exact (brute-force) alternative:
# index = faiss.IndexFlatIP(n_dims)
# index.add(rna_latent)

In [None]:
# For each ATAC cell, retrieve its (approximate) top-k RNA cells by inner
# product: nnd holds the scores and nni the RNA row indices.
n_neighbors = 50
nnd, nni = index.search(atac_latent, n_neighbors)

In [None]:
# RNA cell-type label of every retrieved neighbour (one row per ATAC cell).
# faiss pads the result with index -1 when fewer than k neighbours are found
# (possible with an IVF index). Plain numpy indexing would silently map -1 to
# the label of the *last* RNA cell, so those slots are set to None instead.
rna_cell_types = rna.obs["cell_type"].to_numpy()
valid_hit = nni >= 0
hits = np.where(valid_hit, rna_cell_types[np.where(valid_hit, nni, 0)], None)

In [None]:
# Majority vote over the k neighbours: tally each ATAC cell's neighbour labels
# and keep the most frequent one (ties go to the first column of the tally).
neighbor_votes = pd.Series(
    hits.ravel(), index=np.repeat(atac.obs_names, nni.shape[1])
).dropna()  # drop slots without a valid neighbour (faiss -1 padding)
vote_table = pd.crosstab(neighbor_votes.index, neighbor_votes.to_numpy())
# reindex (rather than .loc) so that a cell without any valid neighbour gets
# NaN instead of raising KeyError; rows follow atac.obs_names order.
pred = vote_table.idxmax(axis=1).reindex(atac.obs_names)
pred = pd.Categorical(pred, categories=rna.obs["cell_type"].cat.categories)
atac.obs["GLUE"] = pred

In [None]:
# Checkpoint the ATAC annotations, now including the "GLUE" column. The
# commented line reloads this file instead of redoing the transfer.
transferred_h5ad = f"{PATH}/atac_transferred.h5ad"
atac.write(transferred_h5ad, compression="gzip")
# atac = anndata.read_h5ad(f"{PATH}/atac_transferred.h5ad")

# Sankey

In [None]:
# Node colours come from the ATAC cell-type palette.
COLOR_MAP = dict(zip(atac.obs["cell_type"].cat.categories, atac.uns["cell_type_colors"]))
# GLUE labels use the RNA categories, which need not all appear in the ATAC
# palette; such labels are drawn grey instead of raising KeyError.
FALLBACK_COLOR = "#CCCCCC"


def node_color(label):
    """Colour of a Sankey node; grey for labels missing from COLOR_MAP."""
    return COLOR_MAP.get(label, FALLBACK_COLOR)


# Links carrying at most 0.1% of the ATAC cells are drawn faintly.
link_cutoff = ceil(atac.shape[0] * 0.001)
# Larger links are grey unless highlighted here with the target's colour.
link_color_map = defaultdict(lambda: FALLBACK_COLOR)
link_color_map.update({
    ("Astrocytes", "Excitatory neurons"): COLOR_MAP["Excitatory neurons"],
    ("Astrocytes/Oligodendrocytes", "Astrocytes"): COLOR_MAP["Astrocytes"],
    ("Astrocytes/Oligodendrocytes", "Oligodendrocytes"): COLOR_MAP["Oligodendrocytes"]
})


def link_color(link):
    """Colour of a Sankey link, given its ``value`` and ``left``/``right`` labels."""
    if link["value"] <= link_cutoff:
        # NOTE(review): CSS rgba channels are 0-255, so this is black at 20%
        # opacity (renders light grey on white); kept to preserve the figure.
        return "rgba(0.9,0.9,0.9,0.2)"
    return link_color_map[(link["left"], link["right"])]


fig = scglue.plot.sankey(
    atac.obs["NNLS"],
    atac.obs["GLUE"],
    title="NNLS vs GLUE transferred labels",
    left_color=node_color,
    right_color=node_color,
    link_color=link_color,
    width=700, height=1400, font_size=14
)
pio.write_image(fig, f"{PATH}/sankey.png", scale=10)

# Accuracy

## Exact match

In [None]:
# Exact matching: the only accepted (NNLS, GLUE) pairs are identical labels.
nnls_categories = atac.obs["NNLS"].cat.categories
match_set = set(zip(nnls_categories, nnls_categories))

In [None]:
# Per-cell flag: is this cell's (NNLS, GLUE) label pair an accepted match?
label_pairs = zip(atac.obs["NNLS"], atac.obs["GLUE"])
match = np.array([pair in match_set for pair in label_pairs])
# Fraction of all ATAC cells whose labels match
match.mean()

## Relaxed match

In [None]:
# Relaxed matching. Besides identical labels, accept:
#   * "X?" (uncertain NNLS label) predicted as "X";
#   * "A/B" (ambiguous NNLS label) predicted as either part, with any "?" and
#     surrounding whitespace also stripped from each part;
#   * manually curated naming differences between NNLS and RNA labels.
# The set is rebuilt from scratch, so this cell is idempotent and does not
# depend on whether the exact-match cell ran before it.
nnls_categories = atac.obs["NNLS"].cat.categories
match_set = {(item, item) for item in nnls_categories}
for item in nnls_categories:
    match_set.add((item, item.replace("?", "")))
    if "/" in item:
        for part in item.split("/"):
            match_set.add((item, part))
            match_set.add((item, part.replace("?", "").strip()))
match_set.update({
    ("Syncytiotrophoblast and villous cytotrophoblasts?", "Syncytiotrophoblasts and villous cytotrophoblasts"),
    # Even more lenient pairs, currently disabled:
    # ("Thymocytes", "Lympoid cells"),
    # ("Myeloid cells", "Microglia"),
    # ("Astrocytes", "Excitatory neurons"),
})

In [None]:
match = np.array([(i, j) in match_set for i, j in zip(atac.obs["NNLS"], atac.obs["GLUE"])])
# Drop cells whose NNLS label mentions "unknown" (case-insensitive) from the
# evaluation. Missing NNLS labels count as unknown too (na=True); this also
# keeps the mask strictly boolean, so that `~` does not fail on NaN entries.
mask = ~atac.obs["NNLS"].str.contains("unknown", case=False, na=True)
# Relaxed accuracy over cells with a known NNLS label
np.sum(np.logical_and(match, mask)) / mask.sum()

In [None]:
# Number of ATAC cells carrying each label, under each annotation
NNLS_size, GLUE_size = (
    atac.obs[column].value_counts().to_dict() for column in ("NNLS", "GLUE")
)

In [None]:
# Mismatched (NNLS, GLUE) label pairs among cells with a known NNLS label,
# ranked by the number of cells each pair covers.
unmatch = atac.obs.loc[
    np.logical_and(~match, mask), ["NNLS", "GLUE"]
].value_counts()
unmatch.name = "count"
unmatch = unmatch.reset_index()
# With categorical columns, value_counts may enumerate every category
# combination, including pairs with no cells; keep observed pairs only.
unmatch = unmatch.loc[unmatch["count"] > 0].reset_index(drop=True)
# Map via object dtype so the size columns are numeric even when the label
# columns are categorical (Categorical.map can return another Categorical).
unmatch["NNLS_size"] = unmatch["NNLS"].astype(object).map(NNLS_size)
unmatch["GLUE_size"] = unmatch["GLUE"].astype(object).map(GLUE_size)
unmatch["frac_all"] = unmatch["count"] / atac.shape[0]
unmatch["frac_NNLS"] = unmatch["count"] / unmatch["NNLS_size"]
unmatch["frac_GLUE"] = unmatch["count"] / unmatch["GLUE_size"]
unmatch.head(n=10)

In [None]:
# Export the full mismatch table (without the positional index)
unmatch_csv = f"{PATH}/unmatch.csv"
unmatch.to_csv(unmatch_csv, index=False)