In [None]:
import os
import yaml
from pathlib import Path

import scanpy as sc
import torch
import numpy as np
import pandas as pd

from scSLAT.model import spatial_match
from scSLAT.metrics import global_score
from scSLAT.viz import match_3D_multi

In [None]:
# parameter cells
# Empty placeholders; presumably overridden at run time by papermill (or a similar
# notebook-parameterisation tool) -- TODO confirm against the pipeline runner.
adata1_file = ''  # .h5ad path of the first slice; its name also selects the annotation columns
adata2_file = ''  # .h5ad path of the second slice
emb0_file = ''  # CSV of embeddings for adata1; row count must equal adata1's cell count
emb1_file = ''  # CSV of embeddings for adata2; row count must equal adata2's cell count
metric_file = ''  # output YAML for the scores; its directory must hold run_time.yaml and receives the PDFs

In [None]:
# Load the two slices that will be aligned.
adata1, adata2 = map(sc.read_h5ad, (adata1_file, adata2_file))

In [None]:
# Read both embedding tables (the first CSV row is treated as the header), then
# hand them to spatial_match as torch tensors.
emb_table0 = pd.read_csv(emb0_file)
emb_table1 = pd.read_csv(emb1_file)
embd0, embd1 = (torch.from_numpy(table.to_numpy()) for table in (emb_table0, emb_table1))
del emb_table0, emb_table1

In [None]:
# Match the two embeddings without smoothing. Only `best` and the row count of
# `index` are used downstream; `distance` is kept for interactive inspection.
best, index, distance = spatial_match([embd0, embd1], smooth=False)
# Expected 2 x n array: row 0 enumerates 0..n-1 (n = index.shape[0]), row 1 is `best`.
# Presumably best[i] is the adata2 cell matched to adata1 cell i -- TODO confirm.
query_ids = range(index.shape[0])
matching = np.array([query_ids, best])

In [None]:
# Pick the obs columns holding the cell-type ("biology") and region ("topology")
# annotations, based on which dataset adata1_file names.
# Every token of a pattern must appear in the path. A bare `'a' and 'b' in s`
# would only test `'b'`, because `in` binds tighter than `and`.
META_KEYS_BY_DATASET = [
    # (tokens required in the path, biology_meta, topology_meta)
    (('visium', 'DLPFC'), 'cell_type', 'layer_guess'),
    (('merfish', 'hypothalamic'), 'Cell_class', 'region'),
    (('stereo', 'embryo'), 'annotation', 'region'),
]
meta_match = next(
    ((bio_key, topo_key)
     for tokens, bio_key, topo_key in META_KEYS_BY_DATASET
     if all(token in adata1_file for token in tokens)),
    None,
)
if meta_match is None:
    # Fail here with a clear message instead of a NameError in the scoring cell.
    raise ValueError(
        f'Cannot infer annotation columns for {adata1_file!r}; expected the path to contain one of: '
        + ', '.join('+'.join(tokens) for tokens, _, _ in META_KEYS_BY_DATASET)
    )
biology_meta, topology_meta = meta_match

In [None]:
# Score the alignment on both annotations jointly, then on each one alone.
slices = [adata1, adata2]
pairs = matching.T  # one row per matched pair: [row-0 entry, row-1 entry] of `matching`
overall_score = global_score(slices, pairs, biology_meta, topology_meta)
celltype_score = global_score(slices, pairs, biology_meta=biology_meta)
region_score = global_score(slices, pairs, topology_meta=topology_meta)

# Save metrics

In [None]:
# run_time.yaml must sit in the same directory as metric_file; that directory is
# also where the plots below are written.
out_dir = Path(os.path.dirname(metric_file))
run_time_dic = yaml.safe_load((out_dir / 'run_time.yaml').read_text())

In [None]:
def _yaml_scalar(value):
    """Return numpy scalars as plain Python numbers; pass anything else through.

    PyYAML's default Dumper serialises numpy scalars (even np.float64, which
    subclasses float) via ``represent_object``. The result is a
    ``!!python/object/apply:numpy...`` tag that ``yaml.safe_load`` cannot read.
    """
    return value.item() if isinstance(value, np.generic) else value


# Collect the three alignment scores plus the recorded run time into one YAML file.
metric_dic = {
    'global_score': _yaml_scalar(overall_score),
    'celltype_score': _yaml_scalar(celltype_score),
    'region_score': _yaml_scalar(region_score),
    'run_time': run_time_dic['run_time'],
}

with open(metric_file, "w") as f:
    yaml.dump(metric_dic, f)

# Plot cross-slice matches

In [None]:
def _slice_frame(adata, n_cells, biology_key, topology_key):
    """Tabulate one slice's spatial coordinates and annotations for match_3D_multi.

    ``n_cells`` is the embedding row count. It must equal the number of spatial
    coordinates, otherwise the DataFrame constructor raises.
    """
    coords = adata.obsm['spatial']
    return pd.DataFrame({'index': range(n_cells),
                         'x': coords[:, 0],
                         'y': coords[:, 1],
                         'celltype': adata.obs[biology_key],
                         'region': adata.obs[topology_key]})


adata1_df = _slice_frame(adata1, embd0.shape[0], biology_meta, topology_meta)
adata2_df = _slice_frame(adata2, embd1.shape[0], biology_meta, topology_meta)
# `matching`, built right after spatial_match, is reused as-is. Recomputing the
# identical array here would only hide where it really comes from.

In [None]:
# Draw the matched pairs between the two slices keyed on the 'celltype' column
# (subsample_size=300 presumably limits how many links are drawn -- TODO confirm).
# The PDF is written next to the metrics file.
multi_align = match_3D_multi(
    adata1_df,
    adata2_df,
    matching,
    meta='celltype',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)
celltype_pdf = out_dir / 'match_by_celltype.pdf'
multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=False,
    save=celltype_pdf,
)

In [None]:
# Same drawing as the cell-type figure, but keyed on the 'region' column.
# The PDF is written next to the metrics file.
multi_align = match_3D_multi(
    adata1_df,
    adata2_df,
    matching,
    meta='region',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)
region_pdf = out_dir / 'match_by_region.pdf'
multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=False,
    save=region_pdf,
)