# Cluster annotation

In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# scip_workflows are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from scip_workflows.common import *


In [None]:
import anndata
import scanpy
import shap
from matplotlib.gridspec import GridSpec
from matplotlib.patches import ConnectionStyle
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

from scip_workflows.core import plot_gate_czi

# Load SHAP's JavaScript support so its interactive plots render in the notebook.
shap.initjs()


In [None]:
# Resolve inputs and outputs. When the notebook runs as a Snakemake rule, a
# `snakemake` object is injected into the namespace; otherwise fall back to
# hard-coded paths for interactive use. Look it up explicitly instead of
# wrapping every assignment in `try/except NameError`, which would also hide
# genuine NameErrors in these lines.
_snakemake = globals().get("snakemake")

if _snakemake is not None:
    adata = _snakemake.input.adata
    output_three = _snakemake.output[0]
    output_cd15_cd45 = _snakemake.output[1]
    output_cd15_siglec8 = _snakemake.output[2]
    # Snakemake passes strings; later cells call Path methods (joinpath, /).
    image_root = Path(_snakemake.input.image_root)
    # NOTE(review): assumes masks/ lives next to the AnnData pickle, mirroring
    # the interactive layout below -- confirm against the Snakemake rule.
    data_dir = Path(adata).parent
    # Not a declared Snakemake output; written next to the first declared figure.
    output_unclassified = Path(output_three).parent / "unclassified_cluster.png"
else:
    image_root = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800")
    data_dir = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800/scip/061020221736/")
    adata = data_dir / "adata.pickle"
    output_three = data_dir / "figures" / "cluster_panels.png"
    output_cd15_cd45 = data_dir / "figures" / "cd15_vs_cd45_facets.png"
    output_cd15_siglec8 = data_dir / "figures" / "cd15_vs_siglec8_facets.png"
    output_unclassified = data_dir / "figures" / "unclassified_cluster.png"


In [None]:
# Display names for the summed-intensity feature columns shown on plot axes.
_MARKER_LABELS = {
    "feat_combined_sum_DAPI": "DAPI",
    "feat_combined_sum_EGFP": "CD45",
    "feat_combined_sum_RPe": "Siglec 8",
    "feat_combined_sum_APC": "CD15",
}


def map_names(a):
    """Return the human-readable marker name for feature column ``a``.

    Raises:
        KeyError: if ``a`` is not one of the four summed-intensity columns
            listed in ``_MARKER_LABELS``.
    """
    return _MARKER_LABELS[a]


In [None]:
# Replace the pickle path held in `adata` with the AnnData object it contains.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(adata, mode="rb") as pickle_file:
    adata = pickle.load(pickle_file)


In [None]:
# Re-root every stored image path: keep the components after the "800"
# directory and join them onto image_root, so images are found even though the
# pickle recorded paths from another location. Path() makes this work when
# image_root arrives as a plain string (e.g. from Snakemake).
_image_root = Path(image_root)


def _rebase_image_path(p):
    """Return ``p`` re-rooted under ``_image_root``.

    Raises ValueError if ``p`` has no "800" path component.
    """
    parts = Path(p).parts
    return _image_root.joinpath(*parts[parts.index("800") + 1 :])


adata.obs.meta_path = adata.obs.meta_path.apply(_rebase_image_path)


In [None]:
# Summed-intensity feature of each fluorescence channel, in var_names order.
marker_prefixes = tuple(
    f"feat_combined_sum_{channel}" for channel in ("EGFP", "RPe", "APC", "DAPI")
)
markers = [name for name in adata.var_names if name.startswith(marker_prefixes)]


In [None]:
# Overview of the raw Leiden clusters: mean marker intensity per cluster,
# clusters on the UMAP, and cluster sizes split by replicate.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
heatmap_ax, umap_ax, count_ax = axes

ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="leiden",
    dendrogram=True,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    ax=heatmap_ax,
    show=False,
    use_raw=False,
)
# Swap raw feature-column names on the x axis for marker names.
main_ax = ax["mainplot_ax"]
main_ax.set_xticklabels(
    [map_names(tick.get_text()) for tick in main_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="leiden", legend_loc="on data", ax=umap_ax, show=False)
seaborn.countplot(data=adata.obs, x="leiden", hue="meta_replicate", ax=count_ax)


In [None]:
# Keep Leiden clusters 2, 4, 6 and 8 as they are and fold every other cluster
# into "1". Cast to categorical explicitly: mapping a categorical many-to-one
# does not return a categorical in pandas, and later cells rely on the `.cat`
# accessor (y_train.cat.categories). Before this, that only worked because a
# later scanpy plotting call happened to convert the column in place.
_kept_clusters = {"2", "4", "6", "8"}
adata.obs["leiden_merged"] = adata.obs.leiden.map(
    lambda a: a if a in _kept_clusters else "1"
).astype("category")


In [None]:
# Same overview as for the raw clusters, now for the merged clustering:
# counts per replicate (left), marker heatmap (middle), UMAP (right).
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
count_ax, heatmap_ax, umap_ax = axes

ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="leiden_merged",
    dendrogram=True,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    ax=heatmap_ax,
    show=False,
    use_raw=False,
)
main_ax = ax["mainplot_ax"]
main_ax.set_xticklabels(
    [map_names(tick.get_text()) for tick in main_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="leiden_merged", ax=umap_ax, show=False)
seaborn.countplot(data=adata.obs, x="leiden_merged", hue="meta_replicate", ax=count_ax)


In [None]:
# CD45 (EGFP) versus CD15 (APC) for all cells, coloured by merged cluster.
cd45_feature, cd15_feature = "feat_combined_sum_EGFP", "feat_combined_sum_APC"
scanpy.pl.scatter(
    adata,
    x=cd45_feature,
    y=cd15_feature,
    color="leiden_merged",
    legend_loc="on data",
)


In [None]:
# One facet per merged cluster: CD45 vs CD15 for that cluster, drawn on top of
# all cells in grey for context.
cd45_col, cd15_col = "feat_combined_sum_EGFP", "feat_combined_sum_APC"
grid = seaborn.FacetGrid(
    data=scanpy.get.obs_df(
        adata, keys=[cd45_col, cd15_col, "leiden_merged"], use_raw=True
    ),
    col="leiden_merged",
)
grid.set_titles(col_template="Cluster {col_name}")

# The grey background is identical for every facet, so fetch it once.
background = scanpy.get.obs_df(adata, keys=[cd45_col, cd15_col], use_raw=True)
for facet_ax in grid.axes.ravel():
    seaborn.scatterplot(
        data=background,
        x=cd45_col,
        y=cd15_col,
        color="grey",
        s=0.5,
        alpha=0.5,
        ax=facet_ax,
    )
grid.map_dataframe(seaborn.scatterplot, x=cd45_col, y=cd15_col, s=1.5)

# Cosmetics go last so they override the labels set by map_dataframe.
for facet_ax in grid.axes.ravel():
    facet_ax.set_yticks([])
    facet_ax.set_xticks([])
    facet_ax.set_xlabel("CD45")
    facet_ax.set_ylabel("CD15")

plt.savefig(output_cd15_cd45, bbox_inches="tight", pad_inches=0, dpi=200)


In [None]:
# Siglec 8 (RPe) versus CD15 (APC) for Leiden clusters 1, 6, 8 and 9 only.
leiden_subset = adata[adata.obs.leiden.isin(["1", "6", "8", "9"])]
scanpy.pl.scatter(
    leiden_subset,
    x="feat_combined_sum_RPe",
    y="feat_combined_sum_APC",
    color="leiden",
    legend_loc="on data",
)


In [None]:
# Per merged cluster, Siglec 8 vs CD15 for cells of Leiden clusters 1, 6, 8 and
# 9, with all cells of that subset drawn in grey behind each facet.
leiden_subset = adata[adata.obs.leiden.isin(["1", "6", "8", "9"])]
siglec8_col, cd15_col = "feat_combined_sum_RPe", "feat_combined_sum_APC"
grid = seaborn.FacetGrid(
    data=scanpy.get.obs_df(
        leiden_subset, keys=[siglec8_col, cd15_col, "leiden_merged"], use_raw=True
    ),
    col="leiden_merged",
)
grid.set_titles(col_template="Cluster {col_name}")

background = scanpy.get.obs_df(
    leiden_subset, keys=[siglec8_col, cd15_col], use_raw=True
)
for facet_ax in grid.axes.ravel():
    seaborn.scatterplot(
        data=background,
        x=siglec8_col,
        y=cd15_col,
        color="grey",
        s=0.5,
        alpha=0.5,
        ax=facet_ax,
    )
grid.map_dataframe(seaborn.scatterplot, x=siglec8_col, y=cd15_col, s=1.5)

for facet_ax in grid.axes.ravel():
    facet_ax.set_yticks([])
    facet_ax.set_xticks([])
    facet_ax.set_xlabel("Siglec 8")
    facet_ax.set_ylabel("CD15")

plt.savefig(output_cd15_siglec8, bbox_inches="tight", pad_inches=0, dpi=200)


## SHAP

In [None]:
# Hold out 10% of the cells, stratified by merged cluster, for evaluating the
# classifier. A fixed random_state makes the split -- and therefore the accuracy
# and SHAP results below -- reproducible; the classifier itself already uses
# random_state=0.
X_train, X_test, y_train, y_test = train_test_split(
    adata[:, adata.var.selected_corr],
    adata.obs["leiden_merged"],
    test_size=0.1,
    stratify=adata.obs["leiden_merged"],
    random_state=0,
)


In [None]:
# Random forest that predicts the merged cluster from the selected features.
train_features = X_train.to_df()
forest = RandomForestClassifier(n_estimators=50, random_state=0)
model = forest.fit(train_features, y_train.values)


In [None]:
# Balanced accuracy on the held-out cells (displayed as the cell output).
test_features = X_test.to_df()
preds = model.predict(test_features)
balanced_accuracy_score(y_true=y_test.values, y_pred=preds)


In [None]:
# SHAP attributions of the random forest on the held-out cells.
test_frame = X_test.to_df()
explainer = shap.TreeExplainer(model)
shap_values = explainer(test_frame)


In [None]:
# Display the merged-cluster labels; requires y_train to be categorical (`.cat`).
# NOTE(review): presumably the same order as model.classes_ (sorted labels),
# which indexes the last axis of shap_values -- confirm before indexing by position.
y_train.cat.categories


In [None]:
# SHAP beeswarm for merged cluster "6". Look the class up by label rather than
# hard-coding its position in model.classes_, which would silently point at a
# different cluster if the set of clusters changes.
shap.plots.beeswarm(shap_values[..., list(model.classes_).index("6")])


In [None]:
# Path of the mask file for each cell's scene/tile: <data_dir>/masks/<scene>_<tile>.npy.
# Format only the file name: %-formatting the already-stringified full path
# breaks as soon as data_dir itself contains a "%".
masks_dir = data_dir / "masks"
adata.obs["meta_masks"] = adata.obs[["meta_scene", "meta_tile"]].apply(
    lambda r: str(masks_dir / f"{r.meta_scene}_{r.meta_tile}.npy"),
    axis=1,
)


In [None]:
# Image gallery (up to 50 cells, channels 0-6) of Leiden cluster 6, using the
# per-cell mask paths computed above.
cluster6_cells = adata.obs["leiden"] == "6"
plot_gate_czi(
    sel=cluster6_cells,
    df=adata.obs,
    channels=list(range(7)),
    maxn=50,
    masks_path_col="meta_masks",
)


In [None]:
# Gallery of Leiden cluster 6 as above, but without masks_path_col; saved to disk.
plot_gate_czi(
    sel=adata.obs["leiden"] == "6",
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6],
    maxn=50,
)
# Use the same export settings as every other figure saved in this notebook.
plt.savefig(output_unclassified, bbox_inches="tight", pad_inches=0, dpi=200)


In [None]:
# 5th/95th percentile of each summed-intensity feature, reshaped into one
# (low, high) row per image channel in the order the channels are plotted.
quantiles = adata.to_df().filter(regex="feat_combined_sum").quantile([0.05, 0.95])
channel_columns = [
    f"feat_combined_sum_{suffix}"
    for suffix in ("DAPI", "EGFP", "RPe", "APC", "Bright", "Oblique", "PGC")
]
extent = quantiles[channel_columns].T.values


In [None]:
# Cluster 6 gallery again, with per-channel display ranges from `extent`.
plot_gate_czi(
    sel=adata.obs["leiden"].eq("6"),
    df=adata.obs,
    channels=list(range(7)),
    maxn=50,
    extent=extent,
)


In [None]:
# Distribution of the CD15 (APC) feature per merged cluster.
scanpy.pl.violin(adata, keys="feat_combined_sum_APC", groupby="leiden_merged")


In [None]:
# SHAP dependence of the CD15 (APC) feature for merged cluster "8". Look the
# class up by label instead of hard-coding its position in model.classes_.
shap.plots.scatter(
    shap_values[..., "feat_combined_sum_APC", list(model.classes_).index("8")]
)


In [None]:
# SHAP beeswarm per merged cluster. The previous hard-coded class index 5 is
# out of range: the merged labelling has only len(model.classes_) == 5 classes
# (indices 0-4). Plot every class, titled by its label, instead.
for class_index, class_label in enumerate(model.classes_):
    shap.plots.beeswarm(shap_values[..., class_index], show=False)
    plt.title(f"Cluster {class_label}")
    plt.show()


In [None]:
# Image gallery (up to 30 cells, channels 0-6) of Leiden cluster 9, with masks.
cluster9_cells = adata.obs["leiden"] == "9"
plot_gate_czi(
    sel=cluster9_cells,
    df=adata.obs,
    channels=list(range(7)),
    maxn=30,
    masks_path_col="meta_masks",
)


# Cell type annotation

In [None]:
# create a dictionary to map cluster to annotation label
# Cell-type label for each merged Leiden cluster.
cluster2annotation = {
    "1": "granulocytes",
    "8": "eosinophils",
    "4": "monocytes",
    "2": "lymphocytes",
    "6": "unclassified",
}

# Ordered categorical, so cell types keep this fixed order (used by the grouped
# plots below) instead of an alphabetical one.
cat_type = pandas.CategoricalDtype(
    categories=[
        "monocytes",
        "lymphocytes",
        "granulocytes",
        "eosinophils",
        "unclassified",
    ],
    ordered=True,
)
cell_type_labels = adata.obs["leiden_merged"].map(cluster2annotation)
adata.obs["cell type"] = cell_type_labels.astype(cat_type)


In [None]:
# Annotation figure: (A) CD45 vs CD15 coloured by merged cluster, (B) Siglec 8 vs
# CD15 for granulocytes and eosinophils, (C) UMAP coloured by cell type.


def _label_panel(panel_ax, letter):
    """Draw a faint, bold panel letter in the top-left corner of ``panel_ax``."""
    panel_ax.text(
        s=letter,
        x=0.02,
        y=1,
        fontsize=20,
        weight="heavy",
        alpha=0.2,
        transform=panel_ax.transAxes,
        va="top",
    )


fig = plt.figure(dpi=200, figsize=(10, 7), constrained_layout=True)
gs = GridSpec(2, 2, figure=fig)

ax = fig.add_subplot(gs[0, 0])
scanpy.pl.scatter(
    adata[adata.obs.leiden_merged.isin(list(cluster2annotation))],
    x="feat_combined_sum_EGFP",
    y="feat_combined_sum_APC",
    color="leiden_merged",
    legend_loc="on data",
    ax=ax,
    show=False,
)
_label_panel(ax, "A")
ax.set_xlabel("CD45")
ax.set_ylabel("CD15")
ax.set_title("")
ax.set_yticks([])
ax.set_xticks([])

ax2 = fig.add_subplot(gs[1, 0])
scanpy.pl.scatter(
    adata[adata.obs["cell type"].isin(["granulocytes", "eosinophils"])],
    x="feat_combined_sum_RPe",
    y="feat_combined_sum_APC",
    color="leiden_merged",
    legend_loc="on data",
    ax=ax2,
    show=False,
)
_label_panel(ax2, "B")
ax2.set_title("")
ax2.set_ylabel("CD15")
ax2.set_xlabel("Siglec 8")
ax2.set_yticks([])
ax2.set_xticks([])

ax3 = fig.add_subplot(gs[:, 1])
# One distinct tab10 colour per "cell type" category, in category order
# (monocytes, lymphocytes, granulocytes, eosinophils, unclassified). The old
# palette listed only 4 colours with green repeated, so lymphocytes and
# eosinophils shared a colour and the fifth category had none of its own.
cell_type_palette = plt.get_cmap("tab10")([8, 2, 4, 1, 7]).tolist()
scanpy.pl.umap(
    adata,
    color=["cell type"],
    legend_loc="on data",
    ax=ax3,
    show=False,
    palette=cell_type_palette,
)
_label_panel(ax3, "C")
ax3.set_title("")
ax3.set_aspect(1)

seaborn.despine(fig)

# NOTE(review): the original saved to an undefined name `output` (NameError).
# Write the figure next to the other figures instead; this file is not a
# declared Snakemake output.
annotation_figure_path = Path(output_three).with_name("cell_type_annotation.png")
plt.savefig(annotation_figure_path, bbox_inches="tight", pad_inches=0, dpi=200)


In [None]:
# Three-panel cell-type summary: counts per replicate, marker heatmap, UMAP.
fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
count_ax, heatmap_ax, umap_ax = axes

ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="cell type",
    dendrogram=False,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    ax=heatmap_ax,
    show=False,
    use_raw=False,
)
main_ax = ax["mainplot_ax"]
main_ax.set_xticklabels(
    [map_names(tick.get_text()) for tick in main_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="cell type", ax=umap_ax, show=False, palette="tab10")
seaborn.countplot(data=adata.obs, y="cell type", hue="meta_replicate", ax=count_ax)

for panel_ax, panel_title in zip(axes, ("Cell type counts", "Marker intensity", "UMAP")):
    panel_ax.set_title(panel_title)
count_ax.legend(title="Replicate")

plt.savefig(output_three, bbox_inches="tight", pad_inches=0, dpi=200)


In [None]:
# Number and fraction of cells per cell type, printed as a LaTeX table.
# Built explicitly: the column produced by value_counts().to_frame() is named
# "count" since pandas 2.0 (previously "cell type"), so indexing it as
# counts["cell type"] raised KeyError on current pandas.
type_counts = adata.obs["cell type"].value_counts()
counts = pandas.DataFrame(
    {"Count": type_counts, "Fraction": type_counts / type_counts.sum()}
)
print(counts.style.to_latex())
