In [5]:
import pickle
import os
import itertools
import time
import yaml
from pathlib import Path

import torch_geometric
import scanpy as sc
import torch
import numpy as np

from scSLAT.utils import global_seed
from scSLAT.model import Cal_Spatial_Net, load_anndatas, run_LGCN, spatial_match
from scSLAT.metrics import global_score

In [None]:
# parameter cell
# Every value below is a placeholder ('' / -1 / 0). They are presumably
# overridden by whatever tool executes this notebook (e.g. papermill
# parameter injection) -- TODO confirm.

# input: paths of the two .h5ad files loaded with sc.read_h5ad
adata1_file = ''
adata2_file = ''

# graph hyperparameter: presumably the neighbour cutoff for the spatial graph
# built by Cal_Spatial_Net -- TODO confirm
k_cutoff = -1

# embed hyperparameter: forwarded to load_anndatas as `feature`
feature_type = ''

# model hyperparameter: forwarded to run_LGCN as `LGCN_layer`
LGCN_layer = 0

# align hyperparameter: forwarded to spatial_match as `top_n`
smooth = -1

# seed: handed to global_seed before the model is run
seed = -1

# output: the two embedding matrices are written as comma-separated text,
# the metrics dictionary is written as YAML
emb0_file = ''
emb1_file = ''
metrics_file = ''

In [None]:
# Read both input datasets, in order, from the paths given in the parameter cell.
adata1, adata2 = (
    sc.read_h5ad(h5ad_path) for h5ad_path in (adata1_file, adata2_file)
)

# Run model

In [None]:
# Fix the random seed before any model code runs.
global_seed(seed)

# Honour the `k_cutoff` hyperparameter from the parameter cell. Previously the
# value 50 was hard-coded here and the parameter was silently ignored. The
# non-positive placeholder (-1) keeps the old behaviour by falling back to 50.
knn_k = k_cutoff if k_cutoff > 0 else 50

start = time.time()

# Build a KNN spatial graph on each dataset (Cal_Spatial_Net is called for its
# effect on the AnnData object; its return value is not used).
Cal_Spatial_Net(adata1, k_cutoff=knn_k, model='KNN')
Cal_Spatial_Net(adata2, k_cutoff=knn_k, model='KNN')

# Extract graph edges and node features for both datasets, then embed them.
edges, features = load_anndatas([adata1, adata2], feature=feature_type)
# `time1` is returned by run_LGCN but is not used by the cells visible here.
embd0, embd1, time1 = run_LGCN(features, edges, LGCN_layer=LGCN_layer)

# Measure the elapsed time once so the printed value and the value stored in
# the metrics file are identical. It is kept as a string, as before.
elapsed = time.time() - start
print('Runtime: ' + str(elapsed))
run_time = str(elapsed)

# Calculate metric

In [None]:
# Convert each embedding to a CPU NumPy array and attach it to its dataset
# under the 'X_slat' key of `.obsm`.
adata1.obsm['X_slat'], adata2.obsm['X_slat'] = (
    embedding.cpu().detach().numpy() for embedding in (embd0, embd1)
)

In [None]:
# Pick the metadata keys handed to global_score, based on which dataset the
# input path refers to (presumably these are column names in `.obs` -- TODO
# confirm).
#
# `'a' and 'b' in path` parses as `'a' and ('b' in path)`; because a non-empty
# string literal is always truthy, only the second keyword was ever checked.
# Each keyword is therefore tested explicitly.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
elif 'merfish' in adata1_file and 'hypothalamic' in adata1_file:
    biology_meta = 'Cell_class'
    topology_meta = 'region'
elif 'stereo' in adata1_file and 'embryo' in adata1_file:
    biology_meta = 'annotation'
    topology_meta = 'region'
else:
    # Fail here with a clear message instead of a NameError on
    # `biology_meta` / `topology_meta` in a later cell.
    raise ValueError(
        f"Cannot infer metadata keys from adata1_file={adata1_file!r}: expected "
        "the path to contain 'visium'+'DLPFC', 'merfish'+'hypothalamic' or "
        "'stereo'+'embryo'."
    )

In [None]:
# Re-read the stored embeddings from both datasets.
embd0, embd1 = (adata.obsm['X_slat'] for adata in (adata1, adata2))

# Match the two embeddings; `smooth` from the parameter cell is passed as top_n.
best, index, distance = spatial_match(
    [embd0, embd1],
    adatas=[adata1, adata2],
    top_n=smooth,
)

# Two-row array: row 0 is 0..index.shape[0]-1, row 1 is `best`.
# It is transposed before being handed to global_score.
matching = np.array([range(index.shape[0]), best])

# Score with both metadata keys, then with each key on its own.
overall_score = global_score(
    [adata1, adata2], matching.T, biology_meta, topology_meta
)
celltype_score = global_score(
    [adata1, adata2], matching.T, biology_meta=biology_meta
)
region_score = global_score(
    [adata1, adata2], matching.T, topology_meta=topology_meta
)

# Save

In [None]:
# Gather the three scores and the runtime into one mapping.
metric_dic = {
    'global_score': overall_score,
    'celltype_score': celltype_score,
    'region_score': region_score,
    'run_time': run_time,
}

# Write the metrics as YAML.
with open(metrics_file, "w") as metrics_fh:
    yaml.dump(metric_dic, metrics_fh)

# Write each embedding matrix as comma-separated text.
np.savetxt(fname=emb0_file, X=adata1.obsm['X_slat'], delimiter=',')
np.savetxt(fname=emb1_file, X=adata2.obsm['X_slat'], delimiter=',')

# Plot

In [None]:
# Optional UMAP visualisation -- currently disabled (the whole cell is
# commented out, so nothing runs here).
# If re-enabled, it would concatenate both datasets, build a cosine neighbour
# graph on 'X_slat', and save UMAP plots coloured by the biology key, the
# topology key and "batch" next to the metrics file.
# NOTE(review): scanpy documents `save` as True or a str that is appended to
# the default file name, so passing a Path may not write into `out_dir` --
# confirm before re-enabling.
# adata_all = adata1.concatenate(adata2)
# out_dir = Path(os.path.dirname(metrics_file))
# sc.pp.neighbors(adata_all, metric="cosine", use_rep='X_slat')
# sc.tl.umap(adata_all)
# sc.pl.umap(adata_all, color=biology_meta,save=out_dir / 'biology.pdf')
# sc.pl.umap(adata_all, color=topology_meta,save=out_dir / 'topology.pdf')
# sc.pl.umap(adata_all, color="batch",save=out_dir / 'batch.pdf')