In [None]:
import os
import yaml
from pathlib import Path
import csv

import scanpy as sc
import torch
import numpy as np
import pandas as pd

from scSLAT.model import spatial_match
from scSLAT.metrics import global_score

In [None]:
# parameter cells
# NOTE(review): the empty-string defaults are placeholders, presumably overridden
# at execution time (e.g. by papermill parameter injection) — confirm with the pipeline.
adata1_file = ''  # .h5ad file of the first slice (read with sc.read_h5ad)
adata2_file = ''  # .h5ad file of the second slice (read with sc.read_h5ad)
emb0_file = ''  # comma-delimited embedding of slice 1; used only when the method is not paste/harmony/pca
emb1_file = ''  # comma-delimited embedding of slice 2; used only when the method is not paste/harmony/pca
metric_file = ''  # output YAML for the metrics; its lowercased name also selects the method branch (paste/harmony/pca/other)
matching_file = ''  # matching file: read as integers for PASTE, written as CSV for every other method
method = ''  # method name; run time is read from <dir of metric_file>/../<method>/run_time.yaml for non-harmony/pca methods

In [None]:
# Load both slices to be aligned (first file, then second).
adata1, adata2 = (sc.read_h5ad(path) for path in (adata1_file, adata2_file))

# Matching

In [None]:
# Build the cross-slice matching. PASTE provides a precomputed matching file;
# every other method is matched here with scSLAT's `spatial_match` on that
# method's embedding (Harmony/PCA from `.obsm`, others from comma-delimited files).
MATCH_SIMILARITY_THRESHOLD = 0.6  # keep only candidates whose similarity is strictly above this

if 'paste' in metric_file.lower():
    matching = np.loadtxt(matching_file, dtype=int)
else:
    # NOTE(review): the Harmony/PCA branches pass numpy arrays while the fallback
    # branch passes torch tensors; confirm spatial_match accepts both.
    if 'harmony' in metric_file.lower():
        embd0 = adata1.obsm['X_harmony']
        embd1 = adata2.obsm['X_harmony']
    elif 'pca' in metric_file.lower():
        embd0 = adata1.obsm['X_pca']
        embd1 = adata2.obsm['X_pca']
    else:
        embd0 = torch.from_numpy(np.loadtxt(emb0_file, delimiter=','))
        embd1 = torch.from_numpy(np.loadtxt(emb1_file, delimiter=','))

    best, index, similarity = spatial_match([embd0, embd1], adatas=[adata1, adata2])
    # Filter out unconfident matches: for each row of `index`, keep only the
    # candidate indices whose paired similarity exceeds the threshold.
    filter_list = [
        candidates[scores > MATCH_SIMILARITY_THRESHOLD].tolist()
        for candidates, scores in zip(index, similarity)
    ]
    # One row per position i: [i, list of kept candidate indices].
    matching = [[i, kept] for i, kept in enumerate(filter_list)]

In [None]:
# Pick the obs columns used for scoring and the plotting spot size from the
# dataset named in the first slice's file path (first matching rule wins).
# Every keyword of a rule must appear in the path, so each membership test is
# written out in full: `'a' and 'b' in s` would only test 'b'.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
    spot_size = 5
elif 'merfish' in adata1_file and 'hypothalamic' in adata1_file:
    biology_meta = 'Cell_class'
    topology_meta = 'region'
    spot_size = 15
elif 'stereo' in adata1_file and 'embryo' in adata1_file:
    biology_meta = 'annotation'
    topology_meta = 'region'
    spot_size = 5
elif 'brain' in adata1_file:
    biology_meta = 'layer_guess'
    topology_meta = 'layer_guess'
    spot_size = 5
else:
    # Fail here with a clear message instead of a NameError in the metrics cell.
    raise ValueError(
        f'Unrecognized dataset in adata1_file={adata1_file!r}; expected the path to contain '
        "'visium'+'DLPFC', 'merfish'+'hypothalamic', 'stereo'+'embryo', or 'brain'"
    )

# Metrics

In [None]:
# Score the matching three ways: using both annotation columns, using only the
# biology (cell-type) column, and using only the topology (region) column.
overall_score = global_score(
    [adata1, adata2], matching,
    biology_meta, topology_meta,
)
celltype_score = global_score(
    [adata1, adata2], matching,
    biology_meta=biology_meta,
)
region_score = global_score(
    [adata1, adata2], matching,
    topology_meta=topology_meta,
)

# Save

In [None]:
# Outputs (and scanpy figures) go next to the metric file.
out_dir = sc.settings.figdir = Path(os.path.dirname(metric_file))

In [None]:
# Run time: Harmony and PCA store their timing in adata1.uns under these keys
# (checked in this order); every other method has a run_time.yaml in
# <parent of out_dir>/<method>/.
uns_time_key = next(
    (key
     for tag, key in (('harmony', 'harmony_time'), ('pca', 'pca_time'))
     if tag in metric_file.lower()),
    None,
)
if uns_time_key is None:
    run_time_path = out_dir.parent / method / 'run_time.yaml'
    with run_time_path.open('r') as stream:
        run_time_dic = yaml.safe_load(stream)
    run_time = run_time_dic['run_time']
else:
    run_time = adata1.uns[uns_time_key]

In [None]:
def _yaml_safe(value):
    """Return `value` as a built-in Python scalar if it is a NumPy scalar or 0-d array.

    yaml.dump serializes NumPy scalars (e.g. np.float64) as
    `!!python/object/apply` tags, which yaml.safe_load refuses to read back.
    Any other value is returned unchanged.
    """
    if isinstance(value, np.generic) or (isinstance(value, np.ndarray) and value.ndim == 0):
        return value.item()
    return value


# -1 marks metrics that this notebook does not compute.
metric_dic = {
    'global_score': overall_score,
    'celltype_score': celltype_score,
    'region_score': region_score,
    'euclidean_dis': -1,
    'angle_delta': -1,
    'run_time': run_time,
    'celltype_macro_f1': -1,
    'celltype_micro_f1': -1,
    'region_macro_f1': -1,
    'region_micro_f1': -1,
    'total_macro_f1': -1,
    'total_micro_f1': -1,
    'match_ratio': -1,
}
# Scores and `.uns`-sourced run times may be NumPy scalars; store plain numbers.
metric_dic = {key: _yaml_safe(value) for key, value in metric_dic.items()}

with open(metric_file, "w") as f:
    yaml.dump(metric_dic, f)

# PASTE's matching was read from `matching_file`, so it is only written for the
# other methods. csv.writer stringifies each row's list of kept indices with str().
if 'paste' not in metric_file.lower():
    with open(matching_file, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(matching)