In [None]:
import copy
import os
import sys
import warnings

# Extend the import path before the project imports below
# (presumably so the local `stamarker` package resolves — confirm the repo layout).
sys.path.append("../")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from scanpy.plotting.palettes import vega_20_scanpy
from scipy.stats import hypergeom
from upsetplot import plot, from_contents

from stamarker.dataset import SpatialDataModule
from stamarker.pipeline import STAMarker, make_spatial_data
from stamarker.utils import parse_args, select_svgs

warnings.filterwarnings("ignore")

# Load data

In [None]:
# Read the Slide-seq .h5ad file and wrap it in a STAMarker data module.
H5AD_PATH = "../dataset/rodriques_slideseq.h5ad"
ann_data = sc.read_h5ad(H5AD_PATH)
print(ann_data)  # text summary of the loaded object

data_module = make_spatial_data(ann_data)
# Pre-processing settings, forwarded unchanged as keyword arguments.
prepare_kwargs = dict(rad_cutoff=40, n_top_genes=3000, min_counts=20)
data_module.prepare_data(**prepare_kwargs)

# STAMarker training procedure

In [None]:
# Merge the model and trainer parameter files into a single config dict.
config = {}
for params_file in ("_params/model.yaml", "_params/trainer.yaml"):
    config.update(parse_args(params_file))

# Without CUDA, clear the `gpus` entry of both trainer sections.
if not torch.cuda.is_available():
    for trainer_key in ("stagate_trainer", "classifier_trainer"):
        config[trainer_key]["gpus"] = None

In [None]:
# Initialize the `STAMarker` object with the merged model/trainer `config`.
# NOTE(review): the first argument (20) is presumably the number of auto-encoders
# and the second the output directory — confirm against `STAMarker.__init__`.
stamarker = STAMarker(20, "Mouse_cerebellum_output/", config)

In [None]:
# Train the auto-encoders on the prepared data module.
stamarker.train_auto_encoders(data_module)

Once the auto-encoders are trained, run Louvain clustering.

In [None]:
# Cluster using the "louvain" method.
# NOTE(review): 0.6 is presumably the Louvain resolution — confirm against `STAMarker.clustering`.
stamarker.clustering(data_module, "louvain", 0.6)

Consensus cluster the labels into 5 classes

In [None]:
# Perform consensus clustering.
# NOTE(review): 5 is presumably the requested number of consensus classes — confirm.
stamarker.consensus_clustering(5)

In [None]:
# Train the classifiers using the consensus labels stored in "consensus_labels.npy".
# NOTE(review): the positional 5 is presumably the number of classes and should match
# the value passed to `consensus_clustering` — confirm against `STAMarker.train_classifiers`.
stamarker.train_classifiers(data_module, 5, consensus_labels_path="consensus_labels.npy")

In [None]:
# Compute `smap` (presumably the saliency maps — confirm); it is log-transformed
# and passed to `select_svgs` further down.
smap = stamarker.compute_smaps(data_module)

# Visualize spatial domains

In [None]:
# Load the consensus labels written to the STAMarker output directory.
# os.path.join avoids the doubled separator that string concatenation produces
# when `save_dir` already ends with "/".
consensus_labels = np.load(os.path.join(stamarker.save_dir, "consensus_labels.npy"))

# Store the labels as strings in `obs` for the plotting cells below.
ann_data.obs["Consensus clustering"] = consensus_labels.astype(str)

# NOTE(review): max + 1 equals the number of domains only if the labels are
# 0-based and contiguous — confirm.
n_class = np.max(consensus_labels) + 1
print("Num of spatial domains", n_class)

In [None]:
# Spatial scatter of the consensus clustering on a single frameless panel.
a = 1.5  # scale factor applied to the base figure size
fig_size = (1.45 * a, 1.42 * a)
fig, ax = plt.subplots(1, 1, figsize=fig_size)

embedding_kwargs = dict(basis="spatial", color="Consensus clustering", s=6,
                        palette=vega_20_scanpy, frameon=False)
sc.pl.embedding(ann_data, show=False, ax=ax, **embedding_kwargs)

# Drop the automatic title and keep x/y on the same scale.
ax.set_title("")
ax.set_aspect("equal")

In [None]:
# One list of selected genes per spatial domain (domains 0..4), taken from the
# log-transformed `smap`.
domain_svg_list = [
    select_svgs(np.log(1 + smap), domain_ind, consensus_labels, alpha=1.25)
    for domain_ind in range(5)
]

# Membership table in upsetplot format, keyed "Spatial domain <i>".
upset_contents = {
    f"Spatial domain {ind}": svg_genes
    for ind, svg_genes in enumerate(domain_svg_list)
}
upset_domains_df = from_contents(upset_contents)

In [None]:
# Bar chart: how many genes are selected in exactly k spatial domains.
fig, ax = plt.subplots(1, 1, figsize=(2.7, 2.5))

# Each index level of `upset_domains_df` is a per-domain membership flag, so the
# row-wise sum is the number of domains a gene belongs to.
membership_flags = upset_domains_df.index.to_frame()
domains_per_gene = membership_flags.apply(np.sum, axis=1)
df = domains_per_gene.to_frame(name="counts")
df.index = upset_domains_df["id"]

# Number of genes for each domain count.
df_counts = df.groupby("counts")["counts"].count()

ax.bar(df_counts.index, df_counts.values)
ax.set_xticks(df_counts.index)
ax.set_xlabel("Number of spatial domains")
ax.set_ylabel("Number of genes")
plt.tight_layout()

In [None]:
# NOTE(review): `compute_morans_I`, `get_svg_domains_df`, `version_dirs` and
# `upset_df_methods` are neither defined nor imported in the visible cells; this
# cell fails on a fresh kernel unless they are provided elsewhere.
morans_i_array = compute_morans_I(ann_data, data_module, version_dirs)
stmarker_domains_df = get_svg_domains_df(ann_data, upset_domains_df, morans_i_array)
# Flag genes that appear in the (True, False, False, False) group of `upset_df_methods`
# (presumably genes found by the first method only — confirm the level order).
stmarker_domains_df["Specific"] = np.isin(stmarker_domains_df.index, upset_df_methods.loc[(True, False, False, False)])

In [None]:
# Print the "module" entry of each target gene.
# NOTE(review): `df` must carry a "module" column here; the `df` built in the
# bar-chart cell only has "counts", so this relies on a frame defined elsewhere.
tgenes = "Syt11/Git2/Gria3/Unc13c/Myh10/Ptprz1/Timp2/Thy1/Nefl/Cplx1/Sncb/Nrsn1/Kcna1/Ucn/Gad1/App/Cck/Arpc2/Ntrk2/Clu/Dixdc1/Kif5c/Clasp2/Pcdh9/Map2/Calb1/Cplx2/Auts2/Grin1/Prnp/Rtn4/Map1b/Pclo/Kcnma1/Apbb1/Kcnc3/Npy/Kif5a".split("/")
for g in tgenes:
    module_values = df.loc[df.index == g, ["module"]].values
    print(g, module_values)

In [None]:
# NOTE(review): `go_methods` is not defined in the visible cells.
# Display the rows of the "HotSpot" entry whose `source` contains GO:0150034.
go = go_methods["HotSpot"]
go[go.source.str.contains("GO:0150034")]

In [None]:
# Genes to visualise; for each one print its module label and whether it is
# flagged as specific in `stmarker_domains_df`.
# NOTE(review): `labels_smaps` and `gene_indices` are not defined in the visible cells.
plot_gene_names = ["Cbln1", "Dab1", "Gria1", "Baiap2", "Nrsn1", "Ucn"]
for gene in plot_gene_names:
    specific_flags = stmarker_domains_df.loc[stmarker_domains_df.index == gene, "Specific"]
    # `.all()` on an empty selection is vacuously True, which would label a gene
    # that is absent from the frame as "Specific"; require at least one row.
    is_specific = (not specific_flags.empty) and bool(specific_flags.all())
    s = "Specific" if is_specific else "Nonspecific"
    print(gene, "M{}".format(labels_smaps[gene_indices.index == gene][0]), s)

In [None]:
# 3 x n_genes figure: expression (top row), saliency (middle row) and the
# spatial domains in which the gene was selected (bottom row).
# NOTE(review): `find_indices`, `zsmap`, `remove_cbar` and `fig_dir` are not
# defined in the visible cells.
def _clear_title_equal_aspect(panel):
    """Remove the panel title and force equal x/y scaling."""
    panel.set_title("")
    panel.set_aspect("equal")


n_genes = len(plot_gene_names)
domain_palette = {str(k): vega_20_scanpy[k] for k in range(n_class)}
plot_gene_indices = find_indices(ann_data, plot_gene_names)
fig, axs = plt.subplots(3, n_genes, figsize=(n_genes * 1.35, 4))

for ind, gene_name in enumerate(plot_gene_names):
    # Bottom row: every domain grey except those whose "Spatial domain" column
    # is truthy for this gene.
    gene_domain_flags = stmarker_domains_df[stmarker_domains_df.index == gene_name].filter(regex="Spatial domain")
    indices_domain = np.where(gene_domain_flags)[1]
    temp_palette = {str(k): "lightgrey" for k in range(n_class)}
    temp_palette.update({str(d): domain_palette[str(d)] for d in indices_domain})
    ax = axs[2][ind]
    sc.pl.embedding(ann_data, basis="spatial", color="Consensus clustering", ax=ax,
                    frameon=False, s=5, legend_loc=None, show=False, palette=temp_palette)
    _clear_title_equal_aspect(ax)

    # Middle row: the gene's column of `zsmap`, clipped at the 0.5 / 99.5 percentiles.
    ax = axs[1][ind]
    ann_data.obs["saliency"] = zsmap[:, plot_gene_indices[ind]]
    sc.pl.embedding(ann_data, basis="spatial", color="saliency", ax=ax,
                    frameon=False, s=5, legend_loc=None, show=False,
                    color_map="magma", vmin="p0.5", vmax="p99.5")
    _clear_title_equal_aspect(ax)

    # Top row: the gene itself, from 0 to the 99.5 percentile.
    ax = axs[0][ind]
    sc.pl.embedding(ann_data, basis="spatial", color=gene_name, ax=ax,
                    frameon=False, s=5, legend_loc=None, show=False,
                    color_map="viridis", vmin=0, vmax="p99.5")
    _clear_title_equal_aspect(ax)

remove_cbar(fig.axes)
plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig(os.path.join(fig_dir, "maps_genes.png"), dpi=300)

In [None]:
# Count, per method, how many of its upper-cased gene names appear in
# `valid_gene_list`.
# NOTE(review): `genes_all` and `valid_gene_list` are not defined in the visible cells.
overlap_rows = []
for key, method_genes in genes_all.items():
    genes_list = {gene.upper() for gene in method_genes}
    n_shared = len(genes_list.intersection(valid_gene_list))
    overlap_rows.append([key, n_shared])

overlap_df = pd.DataFrame(overlap_rows)
overlap_df.columns = ["Method", "Overlap"]
overlap_df

In [None]:
# Hypergeometric enrichment of each method's gene list in the reference set,
# followed by a bar plot of the raw overlap counts.
# NOTE(review): `sns` (seaborn) is not imported in the visible import cell, and
# `genes_all` is not defined in the visible cells.
M = 3000  # population size; equals n_top_genes=3000 passed to prepare_data — TODO confirm that is the intent
n = 261   # successes in the population; presumably the reference gene-set size — TODO confirm

# Iterate the two columns together instead of calling int() on a one-element
# Series, which recent pandas versions deprecate.
for method, k in zip(overlap_df["Method"], overlap_df["Overlap"]):
    # The number of draws is the size of *this* method's gene list, de-duplicated
    # the same way the overlap was counted (upper-cased set).
    N = len({gene.upper() for gene in genes_all[method]})
    # Enrichment p-value is P(X >= k). scipy's sf(x) is P(X > x), hence k - 1.
    print(method, hypergeom.sf(int(k) - 1, M, n, N))

method_colors = {"STAMarker": "#C13E3F", "HotSpot": "#3B9144", "SPARK-X": "#E1822C", "SpatialDE": "#3375A2"}
fig, ax = plt.subplots(1, 1, figsize=(1.8, 2.25))
sns.barplot(data=overlap_df, x="Method", y="Overlap", palette=method_colors, ax=ax)
ax.set_xlabel("")
ax.set_xticklabels(ax.get_xticklabels(), rotation=60)
# ax.set_ylabel("Overlap with reference")
ax.set_ylim([0, 120])
plt.tight_layout()

# Create the output directory first so savefig does not fail on a fresh checkout.
out_path = "figures/app5/barplot_overlap_allen_brain.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, dpi=300)