In [None]:
import time
import yaml

import numpy as np
import pandas as pd
import scanpy as sc
import harmonypy as hm

from scSLAT.utils import global_seed
from scSLAT.model import Cal_Spatial_Net, load_anndatas, run_SLAT, spatial_match, scanpy_workflow
from scSLAT.metrics import global_score, euclidean_dis

In [None]:
# Parameters cell: the values below are placeholders; presumably they are
# overridden at execution time (e.g. by papermill injecting a cell after this one).

# Inputs: the two .h5ad slices to integrate, plus the RNG seed
adata1_file: str = ''
adata2_file: str = ''
seed: int = 0

# Hyperparameters
feature_dim = 0  # forwarded to scanpy_workflow as n_comps
theta = 0        # forwarded to hm.run_harmony as theta
lamb = 0         # forwarded to hm.run_harmony as lamb

# Output paths: per-slice embeddings (CSV) and the metrics file (YAML)
emb0_file: str = ''
emb1_file: str = ''
metrics_file: str = ''

In [None]:
# Seed the RNGs through scSLAT's helper before the PCA / Harmony steps below, so
# reruns with the same `seed` parameter are reproducible.
# NOTE(review): which libraries global_seed seeds (random / numpy / torch) is
# defined in scSLAT.utils — confirm it covers everything Harmony uses.
global_seed(seed)

In [None]:
# Load both slices, then keep a copy of each raw X matrix in layers['counts']
# before anything in this notebook modifies the data.
adata1, adata2 = map(sc.read_h5ad, (adata1_file, adata2_file))
adata1.layers['counts'], adata2.layers['counts'] = adata1.X.copy(), adata2.X.copy()

# PCA and Harmony

In [None]:
# Concatenate the slices (adata1 rows first, then adata2) so PCA and Harmony run in
# one shared space; concatenate() labels each cell's slice in obs['batch'], which is
# the covariate Harmony corrects for below.
adata_all = adata1.concatenate(adata2)

# Time the embedding computation (scanpy_workflow + Harmony). perf_counter is a
# monotonic clock intended for measuring durations; time.time() can jump if the
# system clock is adjusted mid-run.
start = time.perf_counter()
adata_all = scanpy_workflow(adata_all, n_comps=feature_dim)
harm = hm.run_harmony(adata_all.obsm['X_pca'], adata_all.obs, 'batch',
                      theta=theta, lamb=lamb, max_iter_harmony=20)
# Z_corr holds cells along its second axis; transpose to (cells, components)
Z = harm.Z_corr.T
end = time.perf_counter()

adata_all.obsm['X_harmony'] = Z
# Seconds, stored as a string (this is the format written to the metrics file)
run_time = str(end - start)

In [None]:
# Split the joint embedding back per slice: rows follow the concatenation order,
# so the first adata1.n_obs rows belong to adata1 and the remainder to adata2.
adata1.obsm['X_harmony'], adata2.obsm['X_harmony'] = np.split(Z, [adata1.n_obs])

# Calculate metrics

In [None]:
# Choose the obs column names passed to global_score as biology_meta and
# topology_meta, based on which dataset adata1_file points at.
#
# These tests were previously written as `'visium' and 'DLPFC' in adata1_file`.
# Because `in` binds tighter than `and`, that parses as
# `'visium' and ('DLPFC' in adata1_file)`. The non-empty platform literal is always
# truthy, so only the dataset keyword was ever checked.
#
# That effective behavior is kept, so existing input paths still resolve, but it is
# now written explicitly. An unrecognized path now raises instead of leaving the two
# names undefined, or stale from an earlier run in the same kernel.
if 'DLPFC' in adata1_file:              # Visium DLPFC
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
elif 'hypothalamic' in adata1_file:     # MERFISH hypothalamic
    biology_meta = 'Cell_class'
    topology_meta = 'region'
elif 'embryo' in adata1_file:           # Stereo-seq embryo
    biology_meta = 'annotation'
    topology_meta = 'region'
else:
    raise ValueError(
        f"Cannot infer annotation columns from adata1_file={adata1_file!r}; "
        "expected the path to contain 'DLPFC', 'hypothalamic' or 'embryo'."
    )

In [None]:
# Match cells across slices using the per-slice Harmony embeddings
embd0, embd1 = adata1.obsm['X_harmony'], adata2.obsm['X_harmony']
best, index, distance = spatial_match([embd0, embd1])

# 2 x n array: row 0 is 0..n-1, row 1 is the matched index from `best`;
# global_score consumes it transposed as n (i, j) pairs.
matching = np.array([range(index.shape[0]), best])
match_pairs = matching.T

# Combined score, then the biology-only and topology-only variants
overall_score = global_score([adata1, adata2], match_pairs, biology_meta, topology_meta)
celltype_score = global_score([adata1, adata2], match_pairs, biology_meta=biology_meta)
region_score = global_score([adata1, adata2], match_pairs, topology_meta=topology_meta)

In [None]:
# Distance-based alignment metric from scSLAT.metrics (presumably a Euclidean
# distance between matched cells — confirm in scSLAT.metrics).
# NOTE(review): unlike global_score above, this receives the un-transposed 2 x n
# `matching` array — confirm that is the layout euclidean_dis expects.
eud = euclidean_dis(adata1, adata2, matching)

# Save metrics and embeddings

In [None]:
def _to_builtin(value):
    """Return `value` as a plain Python object if it is a numpy scalar/array.

    yaml.dump serializes numpy scalars (e.g. np.float64) with
    `!!python/object/apply:numpy...` tags that yaml.safe_load cannot read, so
    numpy values are converted via .tolist(); anything else is returned unchanged.
    """
    if isinstance(value, (np.generic, np.ndarray)):
        return value.tolist()
    return value


metric_dic = {
    'global_score': overall_score,
    'celltype_score': celltype_score,
    'region_score': region_score,
    'euclidean_dis': eud,
    'run_time': run_time,  # seconds, already a string
}
metric_dic = {key: _to_builtin(value) for key, value in metric_dic.items()}

with open(metrics_file, "w") as f:
    yaml.dump(metric_dic, f)

# One row per cell, one column per Harmony component, for each slice
np.savetxt(emb0_file, adata1.obsm['X_harmony'], delimiter=',')
np.savetxt(emb1_file, adata2.obsm['X_harmony'], delimiter=',')