# Case Study: Haematopoietic Stem Cell

In this case study, we use the Tusi dataset as the reference and predict continuous cell fates across:
1. Different sequencing runs within the Tusi dataset
2. Different species, i.e. mouse (Tusi) and human (Velten)

## Preparation

In [None]:
import json
import collections
import colorsys
import functools
import os
import subprocess
import sys

import numpy as np
import pandas as pd
import scipy.spatial
import scipy.stats
import sklearn.preprocessing
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import scvi.dataset
import scvi.models
import scvi.inference
import scvi.inference.annotation
import Cell_BLAST as cb

sys.path.insert(0, "../../../Evaluation")
import utils

In [None]:
# Global runtime configuration for this notebook.
# Expose only the GPU chosen by `utils.pick_gpu_lowest_memory()` (presumably the least-used one).
os.environ["CUDA_VISIBLE_DEVICES"] = utils.pick_gpu_lowest_memory()
# Cell BLAST random seed and number of parallel jobs.
cb.config.RANDOM_SEED = 0
cb.config.N_JOBS = 4
# Keep SVG text as text (not outlined paths) and use Arial for all figures.
plt.rcParams['svg.fonttype'] = "none"
plt.rcParams['font.family'] = "Arial"

In [None]:
# Lineage (fate) columns used throughout the notebook, in display order.
FATES = ["E", "Ba", "Meg", "Ly", "D", "M", "G"]
# HSV hue in [0, 1] for each entry of FATES (matched by position).
COLORS = np.array([
# NOTE(review): the commented-out row below looks like an earlier hue set (possibly in degrees);
# it does not correspond to the active values.
#     330, 300, 30, 190, 165, 130, 260
    0.0, 0.14, 0.23, 0.48, 0.58, 0.74, 0.88
])  # Hue in HSV

In [None]:
def hsv2hex(x, na="#474747"):
    """Convert an HSV triple to a ``"#rrggbb"`` hex color string.

    Parameters
    ----------
    x : sequence of three floats
        Hue, saturation and value, each in [0, 1], as accepted by
        ``colorsys.hsv_to_rgb``.
    na : str
        Color returned when any component of ``x`` is NaN.

    Returns
    -------
    str
        The hex color, or ``na`` if ``x`` contains NaN.
    """
    if np.any(np.isnan(x)):
        return na
    rgb = colorsys.hsv_to_rgb(x[0], x[1], x[2])
    # Channels are truncated (not rounded) to 0-255.
    return "#%02x%02x%02x" % tuple(int(channel * 255) for channel in rgb)


def jsd(p, q):
    """Jensen-Shannon divergence (natural log) between two distributions.

    ``p`` and ``q`` must support element-wise ``+`` (e.g. NumPy arrays).
    Each side's KL divergence to the midpoint distribution is computed with
    ``scipy.stats.entropy``, and the two are averaged.
    """
    midpoint = (p + q) / 2
    divergences = (
        scipy.stats.entropy(p, midpoint),
        scipy.stats.entropy(q, midpoint),
    )
    return sum(divergences) / 2


def plot_fate(ds, fates, colors, method="SPRING", size=3, width=4.5, height=4.5, sort=False, na="#474747"):
    """Scatter cells on a 2-D embedding, colored by their fate distribution.

    Each cell is colored as follows:

    * Hue is ``colors[k]``, where ``fates[k]`` is the cell's most probable fate.
    * Saturation rises as the entropy of the fate distribution falls, so
      committed cells look vivid and uncommitted cells look pale.
    * Value is fixed at 0.85.

    Cells with any NaN fate probability are drawn in ``na``.

    Side effect: writes ``ds.obs["entropy"]`` and ``ds.obs["color"]``.

    Parameters
    ----------
    ds
        Dataset whose ``obs`` holds the ``fates`` columns and the
        ``<method>1`` / ``<method>2`` coordinate columns.
    fates : list of str
        Fate probability columns in ``ds.obs``.
    colors : numpy.ndarray
        HSV hue in [0, 1] for each entry of ``fates``.
    method : str
        Prefix of the coordinate columns (e.g. ``"SPRING"``).
    size : float
        Marker size.
    width, height : float
        Figure size in inches.
    sort : bool
        If True, draw ``na``-colored cells first so colored cells lie on top.
    na : str
        Color for cells with missing fate values.

    Returns
    -------
    matplotlib.axes.Axes
    """
    fate_df = ds.obs.loc[:, fates]
    mask = ~np.any(np.isnan(fate_df.values), axis=1)
    ds.obs["entropy"] = np.nan
    # 0 * log(0) evaluates to NaN; pandas' sum skips NaN, so those terms count as 0.
    ds.obs.loc[mask, "entropy"] = np.sum(
        -fate_df.loc[mask] * np.log(fate_df.loc[mask]), axis=1)

    n_fates = len(fates)
    hue = colors[fate_df.values.argmax(axis=1)]
    # Normalized certainty: 1 for a one-hot distribution, 0 for a uniform one.
    certainty = (np.log(n_fates) - ds.obs["entropy"].values) / np.log(n_fates)
    # Concave mapping so moderately committed cells are already fairly saturated.
    saturation = 1 - (certainty - 1) ** 2
    value = np.repeat(0.85, ds.shape[0])
    hsv = np.stack([hue, saturation, value], axis=1)
    # A comprehension (instead of np.apply_along_axis) cannot truncate strings to
    # the width of the first result when `na` has a different length.
    ds.obs["color"] = [hsv2hex(row, na=na) for row in hsv]

    color_arr = ds.obs["color"].to_numpy()
    _fig, ax = plt.subplots(figsize=(width, height))
    # False (na-colored) sorts before True, so colored cells are drawn last, on top.
    order = np.argsort(color_arr != na, kind="stable") if sort else np.arange(ds.shape[0])
    # Index positionally via NumPy arrays; obs is indexed by cell names.
    ax.scatter(
        x=ds.obs[method + "1"].to_numpy()[order],
        y=ds.obs[method + "2"].to_numpy()[order],
        c=color_arr[order],
        s=size, edgecolor=None, rasterized=True
    )
    ax.legend(title="Lineage", handles=[
        mpatches.Patch(color=hsv2hex((fate_hue, 1.0, 0.85)), label=fate)
        for fate_hue, fate in zip(colors, fates)
    ], frameon=False, bbox_to_anchor=(1.05, 0.5), loc="center left")
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position("left")
    ax.xaxis.set_ticks_position("bottom")
    ax.set_xlabel(method + "1")
    ax.set_ylabel(method + "2")
    return ax


def fate_marker_correlation(ds, fates, marker_dict, width=15, height=20, plot=True):
    """Spearman correlation between fate probabilities and marker expression.

    Row ``i`` of the result corresponds to ``fates[i]``. Column ``j``
    corresponds to the ``j``-th marker in ``marker_dict[fates[i]]``. Expression
    is normalized with ``ds.normalize()`` and log1p-transformed first.

    Parameters
    ----------
    ds
        Dataset providing ``normalize()``, ``obs`` with the ``fates`` columns,
        and gene-indexable expression (``ds[:, genes]``).
    fates : list of str
        Fate columns in ``ds.obs``; each must be a key of ``marker_dict``.
    marker_dict : dict of str -> list of str
        Marker genes per fate.
    width, height : float
        Figure size in inches (used only when ``plot`` is True).
    plot : bool
        Whether to draw a fate-vs-marker scatter grid.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The scatter grid, or None when ``plot`` is False.
    rhos : numpy.ndarray, shape (len(fates), max markers per fate)
        Spearman's rho. Slots without a marker hold ``-inf``, so
        ``rhos.max(axis=1)`` ignores them.
    """
    nrow = len(fates)
    ncol = max(len(item) for item in marker_dict.values())
    if plot:
        # squeeze=False keeps `axes` 2-D even with a single row or column.
        fig, axes = plt.subplots(
            nrow, ncol, figsize=(width, height), squeeze=False,
            gridspec_kw=dict(wspace=0.4, hspace=0.4)
        )
    else:
        fig, axes = None, None
    ds_marker_only = ds.normalize()[
        :, functools.reduce(np.union1d, marker_dict.values())]
    ds_marker_only.exprs = ds_marker_only.exprs.log1p().toarray()
    rhos = np.full((nrow, ncol), -np.inf)
    for i, fate in enumerate(fates):
        x = ds_marker_only.obs.loc[:, fate].values.ravel()
        markers = marker_dict[fate]
        for j in range(ncol):
            if j >= len(markers):
                # Padding slot for fates with fewer markers; axes exist only when plotting.
                if plot:
                    axes[i, j].axis("off")
                continue
            marker = markers[j]
            y = ds_marker_only[:, marker].exprs.ravel()
            # [0] works for both SpearmanrResult and the newer SignificanceResult.
            rhos[i, j] = scipy.stats.spearmanr(x, y)[0]
            if plot:
                ax = sns.scatterplot(
                    x=x, y=y, edgecolor=None, s=2,
                    rasterized=True, ax=axes[i, j]
                )
                ax.set_title("ρ = %.3f" % rhos[i, j], y=0.8, fontsize=10)
                ax.set_xlabel(fate)
                ax.set_ylabel(marker)
    return fig, rhos


def plot_markers(ds, marker_dict, method="SPRING", width=15, height=20):
    """Plot marker-gene expression on a 2-D embedding, one panel per marker.

    Rows follow the key order of ``marker_dict`` (one row per fate), and
    columns follow each fate's marker list. Unused trailing panels are hidden.

    Parameters
    ----------
    ds
        Dataset providing ``visualize_latent`` for the given ``method``.
    marker_dict : dict of str -> list of str
        Marker genes per fate; keys become the row labels.
    method : str
        Embedding name passed to ``ds.visualize_latent``.
    width, height : float
        Figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
    """
    nrow = len(marker_dict)
    ncol = max(len(markers) for markers in marker_dict.values())
    # squeeze=False keeps `axes` 2-D even for a single fate or a single marker column.
    fig, axes = plt.subplots(
        nrow, ncol, figsize=(width, height), squeeze=False,
        gridspec_kw=dict(wspace=0.2, hspace=0.2)
    )
    for i, (fate, markers) in enumerate(marker_dict.items()):
        for j in range(ncol):
            if j >= len(markers):
                axes[i, j].axis("off")
                continue
            marker = markers[j]
            ax = ds.visualize_latent(
                marker, method=method, scatter_kws=dict(rasterized=True),
                shuffle=False, sort=True, ax=axes[i, j]
            )
            ax.set_title(marker, fontsize=10)
            # Remove per-panel legends; guard in case none was drawn.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()
            ax.xaxis.set_visible(False)
            if j == 0:
                # First column carries the fate name as a row label, without ticks.
                ax.set_ylabel(fate)
                ax.yaxis.labelpad = 20
                ax.yaxis.set_ticks([])
            else:
                ax.yaxis.set_visible(False)
            ax.spines["bottom"].set_visible(False)
            ax.spines["left"].set_visible(False)
    return fig


@torch.no_grad()
def get_scanvi_class_posterior(scanvi_trainer):
    """Collect scANVI class probabilities for all cells as one NumPy array.

    Builds the trainer's posterior and switches the model to eval mode, which
    it keeps afterwards. Each batch's first element is passed through
    ``model.classify``, and the per-batch results are concatenated in
    iteration order. Gradients are disabled by ``torch.no_grad``.
    """
    posterior = scanvi_trainer.create_posterior()
    scanvi_trainer.model.eval()
    probabilities = [
        scanvi_trainer.model.classify(sample_batch)
        # Batches are 5-tuples; only the first element is used.
        for sample_batch, _, _, _, _ in posterior
    ]
    return torch.cat(probabilities).cpu().numpy()

## Tusi

In [None]:
# Load the Tusi dataset and annotate sequencing run / batch.
tusi = cb.data.ExprDataSet.read_dataset("../../../Datasets/data/Tusi/data.h5")
# Batch "basal_bm1" is sequencing run 1; every other batch is labelled run 2.
tusi.obs["Sequencing run"] = "run 2"
tusi.obs.loc[tusi.obs["batch"] == "basal_bm1", "Sequencing run"] = "run 1"
tusi.obs["Sequencing run"] = pd.Categorical(tusi.obs["Sequencing run"])
tusi.obs["Batch"] = pd.Categorical(tusi.obs["batch"])

# Turn per-fate scores into a per-cell distribution: clip negatives to 0, divide by the row sum.
# NOTE(review): a cell whose clipped scores are all 0 becomes NaN here — confirm none exist.
tusi.obs.loc[:, FATES] = tusi.obs.loc[:, FATES].clip(lower=0)
tusi.obs.loc[:, FATES] = tusi.obs.loc[:, FATES].div(
    tusi.obs.loc[:, FATES].sum(axis=1), axis=0)
# Discrete label = most probable fate.
tusi.obs["Discrete fate"] = np.array(FATES)[tusi.obs.loc[:, FATES].to_numpy().argmax(axis=1)]

### Inspect dataset

#### Latent space

In [None]:
# Fit a DIRECTi model on the Seurat-selected genes (10-d latent; cat_dim=None, presumably no categorical latent).
tusi_model = cb.directi.fit_DIRECTi(
    tusi, genes=tusi.uns["seurat_genes"],
    latent_dim=10, cat_dim=None,
    epoch=300, patience=50, random_seed=0
)

In [None]:
# Embed all Tusi cells into the model's latent space.
tusi.latent = tusi_model.inference(tusi)

In [None]:
# Inputs for SPRING: dense pairwise latent distances (O(n^2) memory), the fate
# matrix as "expression", and the fate names as "gene" names.
dist_mat = scipy.spatial.distance.squareform(
    scipy.spatial.distance.pdist(tusi.latent))
np.save("expr.npy", tusi.obs.loc[:, FATES].values)
np.save("dist.npy", dist_mat, allow_pickle=True)
np.save("gene.npy", np.array(FATES), allow_pickle=True)

In [None]:
%%bash
# Build the SPRING dataset for Tusi from expr.npy / dist.npy / gene.npy written above.
eval "$(conda shell.bash hook 2> /dev/null)"  # May be unnecessary depending on the environment setup

# git clone git@github.com:AllonKleinLab/SPRING.git
# conda create -n SPRING python=2.7 numpy scipy scikit-learn matplotlib jupyter

conda activate envs/SPRING
# -k 5: presumably the neighbour count of SPRING's kNN graph — confirm in prep_spring.py.
python prep_spring.py -e "expr.npy" -d "dist.npy" -g "gene.npy" -k 5 -o "SPRING/datasets/Tusi"

# Please start the server elsewhere by running: `python -m SimpleHTTPServer 8000` under directory SPRING and conda environment SPRING
# Now access SPRING web server and adjust the plot: <ip_address>:8000/springViewer.html?datasets/Tusi
# Save the final coordinate to "SPRING/datasets/Tusi/coordinates.txt"

In [None]:
# Load the manually saved SPRING layout; rows are assumed to follow tusi's cell order.
coordinates = pd.read_csv(
    "SPRING/datasets/Tusi/coordinates.txt", index_col=0, header=None
).values
tusi.obs["SPRING1"] = coordinates[:, 0]
tusi.obs["SPRING2"] = coordinates[:, 1]

In [None]:
# SPRING layout colored by batch.
ax = tusi.visualize_latent("Batch", method="SPRING", width=4.5, height=4.5, scatter_kws=dict(rasterized=True))
ax.get_figure().savefig("tusi_batch.pdf", dpi=300, bbox_inches="tight")

In [None]:
# SPRING layout colored by sequencing run.
ax = tusi.visualize_latent("Sequencing run", method="SPRING", width=4.5, height=4.5, scatter_kws=dict(rasterized=True))
ax.get_figure().savefig("tusi_seq_run.pdf", dpi=300, bbox_inches="tight")

In [None]:
# SPRING layout colored by most probable fate.
ax = tusi.visualize_latent("Discrete fate", method="SPRING", width=4.5, height=4.5, scatter_kws=dict(rasterized=True))
ax.get_figure().savefig("tusi_discrete_fate.pdf", dpi=300, bbox_inches="tight")

In [None]:
# SPRING layout colored by the continuous fate distribution.
ax = plot_fate(tusi, FATES, COLORS, "SPRING")
# ax.get_figure().savefig("tusi_fate.pdf", dpi=300, bbox_inches="tight")

#### Fate-marker correlation

In [None]:
# Marker genes per fate; OrderedDict keeps the JSON key order (used as plot row order).
with open("tusi_marker.json", "r") as f:
    tusi_markers = json.load(f, object_pairs_hook=collections.OrderedDict)

In [None]:
fig = plot_markers(tusi, tusi_markers, width=12)
fig.savefig("tusi_marker.pdf", dpi=300, bbox_inches="tight")

In [None]:
fig, rhos = fate_marker_correlation(tusi, FATES, tusi_markers, width=14)
fig.savefig("tusi_fate_marker.pdf", dpi=300, bbox_inches="tight")
# Best-correlated marker per fate, averaged over fates (-inf padding is ignored by max).
print("Overall correlation = %.3f" % rhos.max(axis=1).mean())

### Within dataset BLAST

In [None]:
# Sequencing run 2 is the reference, run 1 the query.
ref = tusi[tusi.obs["Sequencing run"] == "run 2", :]
query = tusi[tusi.obs["Sequencing run"] == "run 1", :]
# Written to disk for the external R / Julia tools used below.
ref.write_dataset("./tusi_run2.h5")
query.write_dataset("./tusi_run1.h5")

In [None]:
# Per-method fate predictions for the query, and per-cell JSD against the true fates.
fate_pred = {}
fate_pred_jsd = {}

#### Cell BLAST

In [None]:
# Ensemble of 4 DIRECTi models on the reference, differing only in random seed.
within_dataset_blast_models = []
for i in range(4):
    print("==== Training model %d ====" % i)
    within_dataset_blast_models.append(cb.directi.fit_DIRECTi(
        ref, genes=ref.uns["seurat_genes"],
        latent_dim=10, cat_dim=None,
        epoch=300, patience=50, random_seed=i
    ))

In [None]:
blast = cb.blast.BLAST(within_dataset_blast_models, ref)

In [None]:
# Query run-1 cells, reconcile the ensemble, and keep hits passing the "pval" 0.05 filter.
hits = blast.query(query).reconcile_models().filter("pval", 0.05)

In [None]:
# One column per fate, annotated from the hits' fate values.
fate_pred["cb"] = pd.concat([hits.annotate(fate, min_hits=1) for fate in FATES], axis=1)

In [None]:
# Cells with any NaN prediction are counted as rejected (presumably no hit survived the filter).
reject_mask = np.any(np.isnan(fate_pred["cb"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# JSD between predicted and true fate distributions, for non-rejected cells only.
fate_pred_jsd["cb"] = np.array([
    jsd(p, q) for p, q in
    zip(fate_pred["cb"].values[~reject_mask], query.obs.loc[:, FATES].values[~reject_mask])
])
assert np.isinf(fate_pred_jsd["cb"]).sum() == 0

In [None]:
ax = sns.distplot(
    fate_pred_jsd["cb"],
    bins=np.linspace(0, 0.7, 50),
    axlabel="Fate prediction JSD (Cell BLAST)"
)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

#### scmap

In [None]:
# Run scmap (R): run 2 as reference "ref", run 1 as query. Paths are relative to `cwd`.
# "--cluster-col organ" is a sham annotation; only the cell-level neighbours are used below.
p = subprocess.run([
    "Rscript", "./run_scmap.R",
    "-r", "../Notebooks/Case/HSC/tusi_run2.h5", "-n", "ref",
    "-q", "../Notebooks/Case/HSC/tusi_run1.h5",
    "-o", "../Notebooks/Case/HSC/tusi_within_dataset_scmap.h5",
    "-g", "scmap_genes", "-s", "0", "--cluster-col", "organ"  # just sham prediction
], cwd="../../../Evaluation", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err, exit_code = p.stdout, p.stderr, p.returncode
print(output.decode())
print(err.decode())
# Fail here rather than later, when a missing (or stale) output file is read.
if exit_code != 0:
    raise RuntimeError("run_scmap.R exited with code %d" % exit_code)

In [None]:
# scmap nearest reference cells (1-based R indices) and their similarities,
# presumably one row per query cell.
scmap_idx = cb.data.read_hybrid_path("tusi_within_dataset_scmap.h5//scmap_cell/nn/ref/cells").astype(int)
scmap_sim = cb.data.read_hybrid_path("tusi_within_dataset_scmap.h5//scmap_cell/nn/ref/similarities")

In [None]:
# Average the reference fate distributions of neighbours with similarity >= 0.5.
# A query with no such neighbour gets an all-NaN row (mean of an empty frame).
fate_pred["scmap"] = pd.DataFrame([
    ref.obs.iloc[_idx[_sim >= 0.5] - 1, :].loc[:, FATES].mean(axis=0)
    for _idx, _sim in zip(scmap_idx, scmap_sim)
])  # idx - 1 because R idx is 1 based

In [None]:
reject_mask = np.any(np.isnan(fate_pred["scmap"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# JSD between predicted and true fate distributions, non-rejected cells only.
fate_pred_jsd["scmap"] = np.array([
    jsd(p, q) for p, q in
    zip(fate_pred["scmap"].values[~reject_mask], query.obs.loc[:, FATES].values[~reject_mask])
])
assert np.isinf(fate_pred_jsd["scmap"]).sum() == 0

In [None]:
ax = sns.distplot(
    fate_pred_jsd["scmap"],
    bins=np.linspace(0, 0.7, 50),
    axlabel="Fate prediction JSD (scmap)"
)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

#### CellFishing.jl

In [None]:
# Run CellFishing.jl (Julia): reference run 2, query run 1, Hamming cutoff 120.
# "--annotation=organ" is a sham annotation; only neighbour indexes/distances are used below.
p = subprocess.run([
    "julia", "./run_CellFishing.jl.jl",
    "--annotation=organ", # just sham prediction
    "--gene=cf_genes", "--seed=0", "--cutoff=120",
    "../Notebooks/Case/HSC/tusi_run2.h5",
    "../Notebooks/Case/HSC/tusi_run1.h5",
    "../Notebooks/Case/HSC/tusi_within_dataset_cf.h5"
], cwd="../../../Evaluation", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err, exit_code = p.stdout, p.stderr, p.returncode
print(output.decode())
print(err.decode())
# Fail here rather than later, when a missing (or stale) output file is read.
if exit_code != 0:
    raise RuntimeError("run_CellFishing.jl.jl exited with code %d" % exit_code)

In [None]:
# CellFishing.jl neighbour indexes (1-based) and Hamming distances.
cf_idx = cb.data.read_hybrid_path("tusi_within_dataset_cf.h5//indexes").astype(int)
cf_sim = cb.data.read_hybrid_path("tusi_within_dataset_cf.h5//hammingdistances")

In [None]:
# Average fates of neighbours within Hamming distance 120 (same value as --cutoff above);
# queries with no such neighbour get all-NaN rows. "- 1" converts 1-based indexes.
fate_pred["cf"] = pd.DataFrame([
    ref.obs.iloc[_idx[_sim <= 120] - 1, :].loc[:, FATES].mean(axis=0)
    for _idx, _sim in zip(cf_idx, cf_sim)
])

In [None]:
reject_mask = np.any(np.isnan(fate_pred["cf"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# JSD between predicted and true fate distributions, non-rejected cells only.
fate_pred_jsd["cf"] = np.array([
    jsd(p, q) for p, q in
    zip(fate_pred["cf"].values[~reject_mask], query.obs.loc[:, FATES].values[~reject_mask])
])
assert np.isinf(fate_pred_jsd["cf"]).sum() == 0

In [None]:
ax = sns.distplot(
    fate_pred_jsd["cf"],
    bins=np.linspace(0, 0.7, 50),
    axlabel="Fate prediction JSD (CellFishing.jl)"
)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

#### CCA anchor

In [None]:
# Run Seurat CCA anchor transfer (R) for every fate column; paths are relative to `cwd` ("..").
p = subprocess.run([
    "Rscript", "./cca_anchor_transfer.R",
    "-r", "HSC/tusi_run2.h5",
    "-q", "HSC/tusi_run1.h5",
    "-g", "seurat_genes",
    "-o", "HSC/tusi_within_dataset_cca_anchor.h5",
    "-a", *FATES, "-d", "20", "-s", "0"
], cwd="..", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err, exit_code = p.stdout, p.stderr, p.returncode
print(output.decode())
print(err.decode())
# Fail here rather than later, when a missing (or stale) output file is read.
if exit_code != 0:
    raise RuntimeError("cca_anchor_transfer.R exited with code %d" % exit_code)

In [None]:
# Predicted fate matrix (query cells x FATES, presumably in the -a order above).
fate_pred["cca_anchor"] = cb.data.read_hybrid_path("tusi_within_dataset_cca_anchor.h5//prediction")

In [None]:
reject_mask = np.any(np.isnan(fate_pred["cca_anchor"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# JSD between predicted and true fate distributions, non-rejected cells only.
fate_pred_jsd["cca_anchor"] = np.array([
    jsd(p, q) for p, q in
    zip(fate_pred["cca_anchor"][~reject_mask], query.obs.loc[:, FATES].values[~reject_mask])
])
assert np.isinf(fate_pred_jsd["cca_anchor"]).sum() == 0

In [None]:
ax = sns.distplot(
    fate_pred_jsd["cca_anchor"],
    bins=np.linspace(0, 0.7, 50),
    axlabel="Fate prediction JSD (CCA anchor)"
)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

#### scANVI

In [None]:
# scANVI labels: integer index of each cell's discrete fate in FATES.
label_mapping = {item: idx for idx, item in enumerate(FATES)}
labels = np.vectorize(lambda x: label_mapping[x])(tusi.obs["Discrete fate"])
# Supervise only on run-2 (reference) cells whose top fate probability exceeds 0.5;
# run-1 cells stay unlabelled and are predicted afterwards.
labelled_indices = np.where(np.logical_and(
    tusi.obs["Sequencing run"] == "run 2",
    tusi.obs.loc[:, FATES].max(axis=1) > 0.5
))[0]

In [None]:
# Export the Seurat-gene subset of the full Tusi dataset (both runs) for scVI.
tusi[:, tusi.uns["seurat_genes"]].to_anndata().write_h5ad("data.h5ad")
tusi_adata = scvi.dataset.AnnDataset("data.h5ad", save_path="./")
# NOTE(review): n_labels counts distinct labels present; it equals len(FATES) only if every
# fate occurs at least once — confirm, otherwise class indices can exceed n_labels.
tusi_adata.labels, tusi_adata.n_labels = labels.reshape((-1, 1)), np.unique(labels).size

In [None]:
# Seed NumPy and PyTorch, then build scANVI and its semi-supervised trainer.
np.random.seed(0)
torch.manual_seed(0)
vae = scvi.models.SCANVI(
    tusi_adata.nb_genes, n_labels=tusi_adata.n_labels,
    n_latent=5, n_hidden=128, n_layers=1
)
trainer = scvi.inference.annotation.CustomSemiSupervisedTrainer(
    vae, tusi_adata, labelled_indices,
    use_cuda=True, metrics_to_monitor=["ll"], frequency=5,
    early_stopping_kwargs=dict(
        early_stopping_metric="ll", save_best_state_metric="ll",
        patience=30, threshold=0
    )
)

In [None]:
trainer.train(n_epochs=1000, lr=1e-3)

In [None]:
# Class posteriors of run-1 (query) cells; columns presumably follow FATES order via label_mapping.
fate_pred["scanvi"] = get_scanvi_class_posterior(trainer)[tusi.obs["Sequencing run"] == "run 1", :]

In [None]:
reject_mask = np.any(np.isnan(fate_pred["scanvi"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# JSD between predicted and true fate distributions, non-rejected cells only.
fate_pred_jsd["scanvi"] = np.array([
    jsd(p, q) for p, q in
    zip(fate_pred["scanvi"][~reject_mask], query.obs.loc[:, FATES].values[~reject_mask])
])
assert np.isinf(fate_pred_jsd["scanvi"]).sum() == 0

In [None]:
ax = sns.distplot(
    fate_pred_jsd["scanvi"],
    bins=np.linspace(0, 0.7, 50),
    axlabel="Fate prediction JSD (scANVI)"
)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

### Compare JSD between true fate and predicted fate

In [None]:
# Method name -> color mapping shared across evaluation figures.
with open("../../../Evaluation/palette_method.json", "r") as f:
    palette = json.load(f)

In [None]:
# Overlay, per method, the JSD histogram (left axis, density) and its cumulative KDE
# (right axis). Methods are drawn in this order, so later ones sit on top.
jsd_methods = [  # (key in fate_pred_jsd, name in palette / legend)
    ("cb", "Cell BLAST"),
    ("cf", "CellFishing.jl"),
    ("cca_anchor", "CCA anchor"),
    ("scmap", "scmap"),
    ("scanvi", "scANVI"),
]
bins = np.linspace(0, 0.7, 50)
fig, ax = plt.subplots(figsize=(4.5, 4.5))
ax2 = ax.twinx()
for key, name in jsd_methods:
    ax = sns.distplot(
        fate_pred_jsd[key], color=palette[name],
        kde=False, hist_kws=dict(density=True, alpha=0.5), bins=bins,
        axlabel="Fate prediction JSD", ax=ax
    )
    ax2 = sns.distplot(
        fate_pred_jsd[key], color=palette[name],
        hist=False, kde_kws=dict(cumulative=True),
        axlabel="Fate prediction JSD", ax=ax2
    )
ax.set_ylabel("Density")
ax2.set_ylabel("Cumulative probability")
# Legend lists methods in reverse drawing order.
_ = plt.legend(handles=[
    mpatches.Patch(color=palette[name], label=name)
    for _key, name in reversed(jsd_methods)
], frameon=False, bbox_to_anchor=(0.97, 0.03), loc="lower right")
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.set_xlim(-0.03, 0.4)
fig.savefig("./tusi_jsd_cmp.pdf", bbox_inches="tight")

## Velten

In [None]:
velten = cb.data.ExprDataSet.read_dataset("../../../Datasets/data/Velten_Smart-seq2/data.h5")

In [None]:
# Map Velten genes onto mouse orthologs so the dataset can be queried against Tusi.
# NOTE(review): columns 1 and 3 presumably hold the human and mouse gene IDs — confirm.
human2mouse = pd.read_csv(
    "../../../Datasets/ortholog/Ensembl/orthology/9606_10090.csv", header=None)
velten2mouse = velten.map_vars(
    human2mouse.iloc[:, [1, 3]], map_uns_slots=["seurat_genes"])
velten2mouse.write_dataset("./velten2mouse.h5")

In [None]:
# Joint Tusi + mapped Velten dataset, used for the superimposed SPRING layout.
tv = cb.data.ExprDataSet.merge_datasets(dict(
    tusi=tusi, velten=velten2mouse
), merge_uns_slots=["seurat_genes"])

In [None]:
# Velten lineages; "M/D" is the sum of Tusi's "M" and "D" fate predictions.
vFATES = ["E", "Ba", "Meg", "Ly", "M/D", "G"]
vCOLORS = np.array([
#     330, 300, 30, 190, 148, 260
    0.0, 0.14, 0.23, 0.48, 0.66, 0.88
])  # Hue in HSV

In [None]:
# Velten marker genes per lineage, in JSON key order.
with open("./velten_marker.json", "r") as f:
    velten_markers = json.load(f, object_pairs_hook=collections.OrderedDict)

In [None]:
# Reset per-method results for the cross-species comparison.
fate_pred = {}
rhos = {}

### Cross species BLAST

#### Cell BLAST

##### Not aligned

In [None]:
# Reuse tusi_model (seed 0) and train seeds 1-3 for a 4-model ensemble on all Tusi cells.
cross_species_blast_models = [tusi_model]
for i in range(1, 4):
    print("==== Training model %d ====" % i)
    cross_species_blast_models.append(cb.directi.fit_DIRECTi(
        tusi, genes=tusi.uns["seurat_genes"],
        latent_dim=10, cat_dim=None,
        epoch=300, patience=50, random_seed=i
    ))

In [None]:
# Saved so the aligned section below can reload this index.
blast = cb.blast.BLAST(cross_species_blast_models, tusi, eps=0.5)
blast.save("./tusi_blast")

In [None]:
hits = blast.query(velten2mouse).reconcile_models().filter("pval", 0.05)

In [None]:
# Predict all Tusi fates, then merge M and D into Velten's "M/D" lineage.
fate_pred["cb"] = pd.concat([hits.annotate(fate, min_hits=1) for fate in FATES], axis=1)
fate_pred["cb"]["M/D"] = fate_pred["cb"]["M"] + fate_pred["cb"]["D"]

In [None]:
reject_mask = np.any(np.isnan(fate_pred["cb"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

##### Aligned

In [None]:
# Reload the unaligned BLAST index saved above.
blast = cb.blast.BLAST.load("./tusi_blast")

In [None]:
# Align the models to the Velten query; outputs go to a random /tmp path.
path = "/tmp/cb/%s" % cb.utils.rand_hex()
print("Aligning BLAST at %s..." % path)
blast_aligned = blast.align(
    velten2mouse, path=path
)

In [None]:
hits = blast_aligned.query(velten2mouse).reconcile_models().filter("pval", 0.05)

In [None]:
# Overwrites the unaligned Cell BLAST prediction.
fate_pred["cb"] = pd.concat([hits.annotate(fate, min_hits=1) for fate in FATES], axis=1)
fate_pred["cb"]["M/D"] = fate_pred["cb"]["M"] + fate_pred["cb"]["D"]

In [None]:
reject_mask = np.any(np.isnan(fate_pred["cb"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# Embed the joint Tusi + Velten dataset with the first aligned model.
tv.latent = blast_aligned.models[0].inference(tv)

In [None]:
# SPRING inputs for the joint dataset; missing fate values (presumably the Velten cells) become 0.
dist_mat = scipy.spatial.distance.squareform(
    scipy.spatial.distance.pdist(tv.latent))
np.save("expr.npy", tv.obs.loc[:, FATES].fillna(0).values)
np.save("dist.npy", dist_mat, allow_pickle=True)
np.save("gene.npy", np.array(FATES), allow_pickle=True)

In [None]:
%%bash
# Build the SPRING dataset for the joint Tusi + Velten (TV) data.
eval "$(conda shell.bash hook 2> /dev/null)"  # May be unnecessary depending on the environment setup

# git clone git@github.com:AllonKleinLab/SPRING.git
# conda create -n SPRING python=2.7 numpy scipy scikit-learn matplotlib jupyter

conda activate envs/SPRING
python prep_spring.py -e "expr.npy" -d "dist.npy" -g "gene.npy" -k 4 -o "SPRING/datasets/TV"

# Please start the server elsewhere by running: `python -m SimpleHTTPServer 8000` under directory SPRING and conda environment SPRING
# Now access SPRING web server and adjust the plot: <ip_address>:8000/springViewer.html?datasets/TV
# Save the final coordinate to "SPRING/datasets/TV/coordinates.txt"

In [None]:
# Joint SPRING layout; rows are assumed to follow tv's cell order.
coordinates = pd.read_csv("SPRING/datasets/TV/coordinates.txt", index_col=0, header=None).values
tv.obs["SPRING1"] = coordinates[:, 0]
tv.obs["SPRING2"] = coordinates[:, 1]

In [None]:
# Tusi cells colored by fate; cells with NaN fates (presumably Velten) drawn in the default na color.
ax = plot_fate(tv, FATES, COLORS, "SPRING")
ax.get_figure().savefig("velten_superimpose.pdf", dpi=300, bbox_inches="tight")

Visualize prediction

In [None]:
# Drop any stale SPRING / fate columns (e.g. from a previous run of this cell) so the
# merges below do not create suffixed duplicates. Then attach the joint-layout
# coordinates and the aligned Cell BLAST predictions by cell name.
stale_columns = ["SPRING1", "SPRING2"] + FATES + vFATES
velten.obs = velten.obs.drop(
    columns=[column for column in velten.obs.columns if column in stale_columns]
)
velten.obs = velten.obs.merge(
    tv.obs.loc[:, ["SPRING1", "SPRING2"]],
    left_index=True, right_index=True, how="left"
)
velten.obs = velten.obs.merge(
    fate_pred["cb"],
    left_index=True, right_index=True, how="left"
)

In [None]:
# Velten cells colored by aligned Cell BLAST predictions; unpredicted cells are white and drawn first.
ax = plot_fate(velten, vFATES, vCOLORS, method="SPRING", size=15, sort=True, na="#FFFFFF")
ax.get_figure().savefig("velten_fate_cb.pdf", dpi=300, bbox_inches="tight")

Marker expression

In [None]:
fig = plot_markers(velten, velten_markers, width=22)
fig.savefig("velten_marker.pdf", dpi=300, bbox_inches="tight")

Fate-marker correlation

In [None]:
# Only cells not rejected by the aligned Cell BLAST query (the latest reject_mask).
velten_use = velten[~reject_mask, :]
fig, rhos["cb"] = fate_marker_correlation(velten_use, vFATES, velten_markers, width=24)
fig.savefig("./velten_fate_marker_cb.pdf", dpi=300, bbox_inches="tight")
print("Overall correlation = %.3f" % rhos["cb"].max(axis=1).mean())

#### scmap

In [None]:
# Run scmap (R): full Tusi dataset as reference "Tusi", mapped Velten as query.
# "--cluster-col organ" is a sham annotation; only the cell-level neighbours are used below.
p = subprocess.run([
    "Rscript", "./run_scmap.R",
    "-r", "../Datasets/data/Tusi/data.h5", "-n", "Tusi",
    "-q", "../Notebooks/Case/HSC/velten2mouse.h5",
    "-o", "../Notebooks/Case/HSC/velten_cross_species_scmap.h5",
    "-g", "scmap_genes", "-s", "0", "--cluster-col", "organ"
], cwd="../../../Evaluation", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err, exit_code = p.stdout, p.stderr, p.returncode
print(output.decode())
print(err.decode())
# Fail here rather than later, when a missing (or stale) output file is read.
if exit_code != 0:
    raise RuntimeError("run_scmap.R exited with code %d" % exit_code)

In [None]:
# Nearest Tusi cells (1-based R indices) and similarities for each Velten cell.
scmap_idx = cb.data.read_hybrid_path("velten_cross_species_scmap.h5//scmap_cell/nn/Tusi/cells").astype(int)
scmap_sim = cb.data.read_hybrid_path("velten_cross_species_scmap.h5//scmap_cell/nn/Tusi/similarities")

In [None]:
# Mean fate of neighbours with similarity >= 0.5 (all-NaN when none), plus merged "M/D".
fate_pred["scmap"] = pd.DataFrame([
    tusi.obs.iloc[_idx[_sim >= 0.5] - 1, :].loc[:, FATES].mean(axis=0)
    for _idx, _sim in zip(scmap_idx, scmap_sim)
])  # idx - 1 because R idx is 1 based
fate_pred["scmap"]["M/D"] = fate_pred["scmap"]["M"] + fate_pred["scmap"]["D"]

In [None]:
reject_mask = np.any(np.isnan(fate_pred["scmap"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# Positional assignment: assumes prediction rows follow velten's cell order.
velten.obs.loc[:, vFATES] = fate_pred["scmap"].loc[:, vFATES].values

In [None]:
ax = plot_fate(velten, vFATES, vCOLORS, method="SPRING", size=15, sort=True, na="#FFFFFF")
ax.get_figure().savefig("velten_fate_scmap.pdf", dpi=300, bbox_inches="tight")

In [None]:
velten_use = velten[~reject_mask, :]
fig, rhos["scmap"] = fate_marker_correlation(velten_use, vFATES, velten_markers, width=24)
fig.savefig("./velten_fate_marker_scmap.pdf", dpi=300, bbox_inches="tight")
print("Overall correlation = %.3f" % rhos["scmap"].max(axis=1).mean())

#### CellFishing.jl

In [None]:
# Run CellFishing.jl (Julia): Tusi reference, mapped Velten query.
# NOTE(review): no --cutoff here, while the distances are filtered at 170 below — confirm
# the tool's default cutoff is at least 170.
p = subprocess.run([
    "julia", "./run_CellFishing.jl.jl",
    "--annotation=organ", # just sham prediction
    "--gene=cf_genes", "--seed=0",
    "../Datasets/data/Tusi/data.h5",
    "../Notebooks/Case/HSC/velten2mouse.h5",
    "../Notebooks/Case/HSC/velten_cross_species_cf.h5"
], cwd="../../../Evaluation", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err, exit_code = p.stdout, p.stderr, p.returncode
print(output.decode())
print(err.decode())
# Fail here rather than later, when a missing (or stale) output file is read.
if exit_code != 0:
    raise RuntimeError("run_CellFishing.jl.jl exited with code %d" % exit_code)

In [None]:
# Neighbour indexes (1-based) and Hamming distances for each Velten cell.
cf_idx = cb.data.read_hybrid_path("velten_cross_species_cf.h5//indexes").astype(int)
cf_sim = cb.data.read_hybrid_path("velten_cross_species_cf.h5//hammingdistances")

In [None]:
# Mean fate of neighbours within Hamming distance 170 (all-NaN when none), plus merged "M/D".
fate_pred["cf"] = pd.DataFrame([
    tusi.obs.iloc[_idx[_sim <= 170] - 1, :].loc[:, FATES].mean(axis=0)
    for _idx, _sim in zip(cf_idx, cf_sim)
])
fate_pred["cf"]["M/D"] = fate_pred["cf"]["M"] + fate_pred["cf"]["D"]

In [None]:
reject_mask = np.any(np.isnan(fate_pred["cf"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
# Positional assignment: assumes prediction rows follow velten's cell order.
velten.obs.loc[:, vFATES] = fate_pred["cf"].loc[:, vFATES].values

In [None]:
ax = plot_fate(velten, vFATES, vCOLORS, method="SPRING", size=20, sort=True, na="#FFFFFF")
ax.get_figure().savefig("velten_fate_cf.pdf", dpi=300, bbox_inches="tight")

In [None]:
velten_use = velten[~reject_mask, :]
fig, rhos["cf"] = fate_marker_correlation(velten_use, vFATES, velten_markers, width=24)
fig.savefig("./velten_fate_marker_cf.pdf", dpi=300, bbox_inches="tight")
print("Overall correlation = %.3f" % rhos["cf"].max(axis=1).mean())

#### CCA anchor

In [None]:
# Run Seurat CCA anchor transfer (R): Tusi reference, mapped Velten query; paths relative to "..".
p = subprocess.run([
    "Rscript", "./cca_anchor_transfer.R",
    "-r", "../../Datasets/data/Tusi/data.h5",
    "-q", "HSC/velten2mouse.h5",
    "-g", "seurat_genes",
    "-o", "HSC/velten_cross_species_cca_anchor.h5",
    "-a", *FATES, "-d", "20", "-s", "0"
], cwd="..", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err, exit_code = p.stdout, p.stderr, p.returncode
print(output.decode())
print(err.decode())
# Fail here rather than later, when a missing (or stale) output file is read.
if exit_code != 0:
    raise RuntimeError("cca_anchor_transfer.R exited with code %d" % exit_code)

In [None]:
# Prediction matrix labelled with velten cell names and FATES columns (assumes matching order).
fate_pred["cca_anchor"] = pd.DataFrame(
    cb.data.read_hybrid_path("velten_cross_species_cca_anchor.h5//prediction"),
    index=velten.obs_names, columns=FATES
)
fate_pred["cca_anchor"]["M/D"] = fate_pred["cca_anchor"]["M"] + fate_pred["cca_anchor"]["D"]

In [None]:
reject_mask = np.any(np.isnan(fate_pred["cca_anchor"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
velten.obs.loc[:, vFATES] = fate_pred["cca_anchor"].loc[:, vFATES].values

In [None]:
ax = plot_fate(velten, vFATES, vCOLORS, method="SPRING", size=20, sort=True, na="#FFFFFF")
ax.get_figure().savefig("velten_fate_cca_anchor.pdf", dpi=300, bbox_inches="tight")

In [None]:
velten_use = velten[~reject_mask, :]
fig, rhos["cca_anchor"] = fate_marker_correlation(velten_use, vFATES, velten_markers, width=24)
fig.savefig("./velten_fate_marker_cca_anchor.pdf", dpi=300, bbox_inches="tight")
print("Overall correlation = %.3f" % rhos["cca_anchor"].max(axis=1).mean())

#### scANVI

In [None]:
# Tusi + mapped Velten, both restricted to Tusi's Seurat genes, as scANVI input.
combined_dataset = cb.data.ExprDataSet.merge_datasets({
    "tusi": tusi[:, tusi.uns["seurat_genes"]],
    "velten2mouse": velten2mouse[:, tusi.uns["seurat_genes"]]
})

In [None]:
# scANVI labels: integer index of each cell's discrete fate in FATES.
label_mapping = {item: idx for idx, item in enumerate(FATES)}
# Cells without a discrete fate (missing after merging, presumably the Velten cells)
# get placeholder class 0 ("any class"). pd.isnull catches np.nan, None and pd.NA alike;
# a dict keyed on np.nan matches only by object identity and is fragile.
labels = np.array([
    0 if pd.isnull(fate) else label_mapping[fate]
    for fate in combined_dataset.obs["Discrete fate"]
])
# Only confidently annotated Tusi cells are supervised, so placeholders are never used as targets.
labelled_indices = np.where(np.logical_and(
    combined_dataset.obs["dataset_name"] == "Tusi",
    combined_dataset.obs.loc[:, FATES].max(axis=1) > 0.5
))[0]
# Dataset of origin encoded as integer batch index.
batch_indices = sklearn.preprocessing.LabelEncoder().fit_transform(combined_dataset.obs["dataset_name"])

In [None]:
# Export for scVI (overwrites the data.h5ad written in the within-dataset section).
combined_dataset.to_anndata().write_h5ad("data.h5ad")
combined_adata = scvi.dataset.AnnDataset("data.h5ad", save_path="./")
# Attach class labels and dataset-of-origin batch indices as column vectors.
combined_adata.labels, combined_adata.n_labels = \
    labels.reshape((-1, 1)), np.unique(labels).size
combined_adata.batch_indices, combined_adata.n_batches = \
    batch_indices.reshape((-1, 1)), np.unique(batch_indices).size

In [None]:
# Seed NumPy and PyTorch, then build a batch-aware scANVI and its semi-supervised trainer.
np.random.seed(0)
torch.manual_seed(0)
vae = scvi.models.SCANVI(
    combined_adata.nb_genes,
    n_labels=combined_adata.n_labels, n_batch=combined_adata.n_batches,
    n_latent=5, n_hidden=128, n_layers=1
)
trainer = scvi.inference.annotation.CustomSemiSupervisedTrainer(
    vae, combined_adata, labelled_indices,
    use_cuda=True, metrics_to_monitor=["ll"], frequency=5,
    early_stopping_kwargs=dict(
        early_stopping_metric="ll", save_best_state_metric="ll",
        patience=30, threshold=0
    )
)

In [None]:
trainer.train(n_epochs=1000, lr=1e-3)

In [None]:
# Class posteriors of the Velten cells, labelled with FATES
# (assumes n_labels == len(FATES) and that the Velten rows follow velten2mouse's order).
fate_pred["scanvi"] = pd.DataFrame(
    get_scanvi_class_posterior(trainer)[
        combined_dataset.obs["dataset_name"] == "Velten_Smart-seq2", :
    ], index=velten2mouse.obs.index, columns=FATES)
fate_pred["scanvi"]["M/D"] = fate_pred["scanvi"]["M"] + fate_pred["scanvi"]["D"]

In [None]:
reject_mask = np.any(np.isnan(fate_pred["scanvi"]), axis=1)
print("Rejection rate = %.3f" % (reject_mask.sum() / reject_mask.size))

In [None]:
velten.obs.loc[:, vFATES] = fate_pred["scanvi"].loc[:, vFATES].values

In [None]:
ax = plot_fate(velten, vFATES, vCOLORS, method="SPRING", size=20, sort=True, na="#FFFFFF")
ax.get_figure().savefig("velten_fate_scanvi.pdf", dpi=300, bbox_inches="tight")

In [None]:
velten_use = velten[~reject_mask, :]
fig, rhos["scanvi"] = fate_marker_correlation(velten_use, vFATES, velten_markers, width=24)
fig.savefig("./velten_fate_marker_scanvi.pdf", dpi=300, bbox_inches="tight")
print("Overall correlation = %.3f" % rhos["scanvi"].max(axis=1).mean())

### Compare correlation

In [None]:
# Internal method key -> display name; the order fixes the x-axis order of the comparison plot.
method_mapping = collections.OrderedDict(
    scmap="scmap",
    cf="CellFishing.jl",
    cca_anchor="CCA anchor",
    scanvi="scANVI",
    cb="Cell BLAST"
)

In [None]:
# Long-format table: per method and lineage, the mean of the top-3 finite marker
# correlations (the -inf padding and NaN are excluded by isfinite).
df = pd.DataFrame({
    key: pd.Series(np.apply_along_axis(
        lambda x: np.sort(x[np.isfinite(x)])[-3:].mean(),
        axis=1, arr=val
    ), index=vFATES)
    for key, val in rhos.items()
}).reset_index().rename(
    columns={"index": "Lineage", **method_mapping}
).melt(
    id_vars=["Lineage"],
    var_name="Method", value_name="Spearman's ρ"
)
df["Method"] = pd.Categorical(
    df["Method"], categories=method_mapping.values()
)

In [None]:
# Boxplot of per-lineage correlations per method, with lineages overlaid as colored lines/points.
fig, ax = plt.subplots(figsize=(4.0, 4.0))
ax = sns.boxplot(
    x="Method", y="Spearman's ρ", data=df,
    linewidth=1.0, width=0.6, whis=10, ax=ax
)
# Fade the boxes: edge RGB at alpha 0.2 and face filled with the edge RGB at alpha 0.1
# (the original face RGB is discarded).
# NOTE(review): relies on boxes being in ax.artists and 6 Line2D objects per box —
# seaborn/matplotlib-version dependent; confirm for the installed versions.
for i, patch in enumerate(ax.artists):  # Make grey
    r, g, b, a = patch.get_edgecolor()
    patch.set_edgecolor((r, g, b, 0.2))
    _r, _g, _b, a = patch.get_facecolor()
    patch.set_facecolor((r, g, b, 0.1))
    for j in range(6 * i, 6 * (i + 1)):
        ax.lines[j].set_color((r, g, b, 0.2))
ax = sns.lineplot(
    x="Method", y = "Spearman's ρ", hue="Lineage", data=df,
    palette=[hsv2hex((c, 1.0, 0.85)) for c in vCOLORS], legend=False
)
ax = sns.scatterplot(
    x="Method", y = "Spearman's ρ", hue="Lineage", data=df,
    palette=[hsv2hex((c, 1.0, 0.85)) for c in vCOLORS], edgecolor=None
)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.set_xticklabels(ax.get_xticklabels(), rotation=20)
_ = plt.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), borderaxespad=0.0, frameon=False)
fig.savefig("velten_rho_cmp.pdf", bbox_inches="tight")