In [None]:
import yaml
import os
import time
from pathlib import Path

import scanpy as sc
import torch
import numpy as np
import pandas as pd

from scSLAT.model import load_anndatas, run_SLAT, Cal_Spatial_Net, spatial_match
from scSLAT.metrics import global_score
from scSLAT.viz import match_3D_multi

In [None]:
# Figure resolution: 150 dpi for inline display, 200 dpi for saved files.
sc.set_figure_params(dpi=150, dpi_save=200)

In [None]:
# Parameters cell: every path defaults to an empty string and is meant to be
# overridden when the notebook is executed (presumably by papermill — confirm).
adata1_file: str = ''    # first slice, read with sc.read_h5ad
adata2_file: str = ''    # second slice, read with sc.read_h5ad
metric_file: str = ''    # YAML metrics output; its directory also receives the figures
emb0_file: str = ''      # comma-separated text dump of slice-1 obsm['X_slat']
emb1_file: str = ''      # comma-separated text dump of slice-2 obsm['X_slat']
matching_file: str = ''  # integer ('%i') text dump of the matching array

In [None]:
# Load the two slices to be aligned (slice 1 first, then slice 2).
adata1, adata2 = (sc.read_h5ad(path) for path in (adata1_file, adata2_file))

# Dataset-specific parameters

In [None]:
# Choose the obs columns used for scoring/plotting and the LGCN depth from
# substrings of the slice-1 path.
# Each rule: (substrings that must ALL occur in adata1_file,
#             biology_meta, topology_meta, LGCN_layer).
# Every substring is tested separately: a chained `'a' and 'b' in s` would
# evaluate as `'a' and ('b' in s)` and silently ignore 'a'.
# Rules are tried in order; the first match wins.
DATASET_RULES = [
    (('visium', 'DLPFC'), 'cell_type', 'layer_guess', 2),
    (('merfish', 'hypothalamic'), 'Cell_class', 'region', 2),
    (('stereo', 'embryo'), 'annotation', 'region', 1),
    (('brain',), 'layer_guess', 'layer_guess', 2),
]

_matched_rule = next(
    (rule for rule in DATASET_RULES
     if all(token in adata1_file for token in rule[0])),
    None,
)
if _matched_rule is None:
    # Fail here instead of with a NameError (or stale values from an earlier
    # kernel run) several cells later.
    raise ValueError(
        f'Unrecognised dataset, cannot choose annotation columns for: {adata1_file!r}'
    )
_, biology_meta, topology_meta, LGCN_layer = _matched_rule

# Run SLAT

In [None]:
# Build a spatial neighbour graph for each slice (model='KNN', k_cutoff=20),
# load PCA features, and train SLAT to obtain one embedding per slice.
# The wall-clock time of the whole cell is recorded for the metrics file.
start = time.perf_counter()  # monotonic clock, suited to interval timing
Cal_Spatial_Net(adata1, k_cutoff=20, model='KNN')
Cal_Spatial_Net(adata2, k_cutoff=20, model='KNN')
edges, features = load_anndatas([adata1, adata2], feature='PCA')
# NOTE(review): the positional `6` is forwarded to run_SLAT unnamed; its
# meaning is not visible here — confirm and pass it by keyword.
embd0, embd1, time1 = run_SLAT(features, edges, 6, LGCN_layer=LGCN_layer)
# Measure once so the printed runtime and the saved runtime are the same value.
elapsed = time.perf_counter() - start
run_time = str(elapsed)
print('Runtime: ' + run_time)

In [None]:
# Copy each SLAT embedding to host memory, drop autograd history, and attach
# it to its slice as a NumPy array under obsm['X_slat'].
for _adata, _emb in ((adata1, embd0), (adata2, embd1)):
    _adata.obsm['X_slat'] = _emb.cpu().detach().numpy()
del _adata, _emb

# Metrics

In [None]:
# Re-read the embeddings as NumPy arrays and match spots across slices.
embd0, embd1 = adata1.obsm['X_slat'], adata2.obsm['X_slat']
best, index, distance = spatial_match([embd0, embd1], adatas=[adata1, adata2])
# `matching` stacks two rows: 0..n_query-1 and `best`.
# Presumably best[i] is the partner of query spot i — confirm against spatial_match.
n_query = index.shape[0]
matching = np.array([range(n_query), best])

In [None]:
# Score the matching three ways: both annotation columns jointly, the biology
# column alone, and the topology column alone.
slices = [adata1, adata2]
pairs = matching.T  # one row per (query, match) column of `matching`
overall_score = global_score(slices, pairs, biology_meta, topology_meta)
celltype_score = global_score(slices, pairs, biology_meta=biology_meta)
region_score = global_score(slices, pairs, topology_meta=topology_meta)

# Save

In [None]:
# Write the three scores and the runtime string to a YAML file.
# NOTE(review): if global_score returns NumPy scalars, yaml.dump emits
# python/object tags that yaml.safe_load cannot read — consider float() casts.
metric_dic = {
    'global_score': overall_score,
    'celltype_score': celltype_score,
    'region_score': region_score,
    'run_time': run_time,
}
with open(metric_file, "w") as metric_fh:
    yaml.dump(metric_dic, metric_fh)

# Embeddings as comma-separated text; matching as integer text (default delimiter).
np.savetxt(emb0_file, adata1.obsm['X_slat'], delimiter=',')
np.savetxt(emb1_file, adata2.obsm['X_slat'], delimiter=',')
np.savetxt(matching_file, matching, fmt='%i')

# Plot

In [None]:
# Figures are written to the directory that holds the metrics file.
metric_dir = os.path.dirname(metric_file)
out_dir = Path(metric_dir)

In [None]:
# Disabled: joint UMAP of both slices on the SLAT embedding ('X_slat', cosine
# neighbours), coloured by the biology column, the topology column and batch.
# Uncomment to regenerate these figures.
# NOTE(review): scanpy's `save=` is documented as a filename suffix written
# under sc.settings.figdir, not a full output path — confirm before enabling.
# adata_all = adata1.concatenate(adata2)
# sc.pp.neighbors(adata_all, metric="cosine", use_rep='X_slat')
# sc.tl.umap(adata_all)
# sc.pl.umap(adata_all, color=biology_meta, save=out_dir / 'biology.pdf')
# sc.pl.umap(adata_all, color=topology_meta, save=out_dir / 'topology.pdf')
# sc.pl.umap(adata_all, color="batch", save=out_dir / 'batch.pdf')

In [None]:
def _spatial_frame(adata, biology_col, topology_col):
    """Collect per-spot plotting data for one slice.

    Returns a DataFrame with a positional 'index' column (0..n_obs-1), the
    first two columns of ``adata.obsm['spatial']`` as 'x' and 'y', and the
    ``biology_col`` / ``topology_col`` obs columns renamed to 'celltype' /
    'region' (the names match_3D_multi is given via ``meta=`` below).
    """
    spatial = adata.obsm['spatial']
    return pd.DataFrame({'index': range(adata.n_obs),
                         'x': spatial[:, 0],
                         'y': spatial[:, 1],
                         'celltype': adata.obs[biology_col],
                         'region': adata.obs[topology_col]})


adata1_df = _spatial_frame(adata1, biology_meta, topology_meta)
adata2_df = _spatial_frame(adata2, biology_meta, topology_meta)
# The plots below reuse the `matching` array built in the Metrics section unchanged.

In [None]:
# 3D view of the cross-slice matching, coloured by the 'celltype' column
# (300-spot subsample, scaled coordinates).
multi_align = match_3D_multi(
    adata1_df,
    adata2_df,
    matching,
    meta='celltype',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)

multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=False,
    save=out_dir / 'match_by_celltype.pdf',
)

In [None]:
# Same 3D matching view, coloured by the 'region' column instead.
multi_align = match_3D_multi(
    adata1_df,
    adata2_df,
    matching,
    meta='region',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)

multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=False,
    save=out_dir / 'match_by_region.pdf',
)