In [None]:
import copy
import os
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from scanpy.plotting.palettes import vega_20_scanpy
from upsetplot import plot, from_contents

sys.path.append("../")
from stamarker.dataset import SpatialDataModule
from stamarker.pipeline import STAMarker, make_spatial_data
from stamarker.utils import parse_args, select_svgs

warnings.filterwarnings("ignore")

# Load data

In [None]:
# Read the Slide-seq AnnData object, show its summary, and wrap it in a
# STAMarker data module that is then preprocessed.
ann_data = sc.read_h5ad("../dataset/rodriques_slideseq.h5ad")
print(ann_data)
data_module = make_spatial_data(ann_data)
# NOTE(review): rad_cutoff presumably sets the radius of the spatial neighbour
# graph, and n_top_genes / min_counts the gene filtering — confirm against
# SpatialDataModule.prepare_data.
data_module.prepare_data(
    rad_cutoff=40,
    n_top_genes=3000,
    min_counts=20,
)

# STAMarker training procedure

In [None]:
# Merge the model and trainer hyper-parameters from the YAML files into one
# dict; on key collisions the values from trainer.yaml win (dict.update order).
config = {}
config.update(parse_args("_params/model.yaml"))
config.update(parse_args("_params/trainer.yaml"))
# Without CUDA, clear the `gpus` setting of both trainers.
# NOTE(review): presumably this makes both trainers fall back to CPU — confirm.
if not torch.cuda.is_available():
    config["stagate_trainer"]["gpus"] = None
    config["classifier_trainer"]["gpus"] = None

In [None]:
# Build the STAMarker pipeline object from the merged config.
stamarker = STAMarker(
    20,  # NOTE(review): presumably the ensemble size (number of models) — confirm
    "Mouse_cerebellum_output/",  # output directory, later read back as `stamarker.save_dir`
    config,
)

In [None]:
# Train the autoencoders on the preprocessed spatial data held by `data_module`.
# NOTE(review): this is a long-running step; the clustering cell that follows
# presumably depends on these trained models — confirm in stamarker.pipeline.
stamarker.train_auto_encoders(data_module)

Cluster the data with the Louvain algorithm using the trained autoencoders.

In [None]:
# Cluster with the Louvain algorithm.
# NOTE(review): 0.6 is presumably the Louvain resolution — confirm against
# STAMarker.clustering's signature.
stamarker.clustering(data_module, "louvain", 0.6)

Run consensus clustering to merge the cluster labels into 5 classes (spatial domains).

In [None]:
# Perform consensus clustering of the cluster labels into 5 classes (spatial domains).
# NOTE(review): keep this count in sync with the `5` passed to train_classifiers.
stamarker.consensus_clustering(5)

In [None]:
# Train the classifiers on the consensus labels.
# NOTE(review): the `5` must match the class count given to consensus_clustering;
# "consensus_labels.npy" is presumably resolved relative to the STAMarker save
# directory (the labels are later loaded from `stamarker.save_dir`) — confirm.
stamarker.train_classifiers(data_module, 5, consensus_labels_path="consensus_labels.npy")

In [None]:
# Compute `smap` (presumably the saliency maps of the trained classifiers — confirm);
# it is log-transformed later for per-domain gene selection with `select_svgs`.
smap = stamarker.compute_smaps(data_module)

# Visualize spatial domains

In [None]:
# Load the consensus labels from the STAMarker output directory and attach them
# to the AnnData object as a string column so scanpy colours them categorically.
# os.path.join avoids the doubled "/" produced when save_dir already ends in "/".
consensus_labels = np.load(os.path.join(stamarker.save_dir, "consensus_labels.npy"))
# Fail with a clear message if the labels do not line up one-to-one with the
# observations in `ann_data` (e.g. if preprocessing dropped observations).
assert len(consensus_labels) == ann_data.n_obs, (
    f"{len(consensus_labels)} consensus labels for {ann_data.n_obs} observations"
)
ann_data.obs["Consensus clustering"] = consensus_labels.astype(str)
# Labels are treated as 0-based integers, so the domain count is max label + 1.
n_class = np.max(consensus_labels) + 1
print("Num of spatial domains", n_class)

In [None]:
# Spatial scatter plot coloured by consensus domain, with an equal aspect ratio
# and no title or frame.
a = 1.5  # figure scale factor
fig, ax = plt.subplots(1, 1, figsize=(1.45 * a, 1.42 * a))
sc.pl.embedding(
    ann_data,
    basis="spatial",
    color="Consensus clustering",
    palette=vega_20_scanpy,
    s=6,
    frameon=False,
    show=False,
    ax=ax,
)
ax.set(title="", aspect="equal")

In [None]:
# Select genes for every spatial domain with `select_svgs` (presumably spatially
# variable genes — confirm) from the log-transformed `smap`, then convert the
# per-domain gene sets into upsetplot's membership format.
log_smap = np.log(1 + smap)  # loop-invariant: compute once, not once per domain
domain_svg_list = []
# Iterate over every detected domain (n_class from the labels) rather than a
# hard-coded count, so this stays correct if the number of domains changes.
for domain_ind in range(n_class):
    domain_svg_list.append(select_svgs(log_smap, domain_ind, consensus_labels, alpha=1.25))
upset_domains_df = from_contents({f"Spatial domain {ind}": genes for ind, genes in enumerate(domain_svg_list)})

In [None]:
# Bar chart: how many selected genes belong to exactly k spatial domains.
fig, ax = plt.subplots(1, 1, figsize=(2.7, 2.5))
# `from_contents` encodes membership as a boolean MultiIndex (one level per
# domain), so a vectorized row sum counts the domains each gene was selected for.
df = upset_domains_df.index.to_frame(index=False).sum(axis=1).to_frame("counts")
df.index = upset_domains_df["id"]  # index rows by gene id
df_counts = df.groupby("counts")["counts"].agg("count")  # genes per domain count
ax.bar(df_counts.index, df_counts.values)
ax.set_xticks(df_counts.index)
ax.set_xlabel("Number of spatial domains")
ax.set_ylabel("Number of genes")
fig.tight_layout()