In [None]:
import os
import yaml
from pathlib import Path

import torch
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, confusion_matrix, ConfusionMatrixDisplay


from scSLAT.model import spatial_match
from scSLAT.metrics import global_score, euclidean_dis, rotation_angle
from scSLAT.viz import match_3D_multi, matching_2d, Sankey

In [None]:
# Parameter cell: placeholder defaults, presumably overridden by an external
# runner (e.g. papermill) before the notebook is executed — TODO confirm.
adata1_file = ''  # path of the first slice (.h5ad); its name also selects the dataset-specific metadata columns
adata2_file = ''  # path of the second slice (.h5ad)
emb0_file = ''  # CSV with embeddings of slice 1; only read when the method is not paste/harmony/pca
emb1_file = ''  # CSV with embeddings of slice 2; same condition as emb0_file
metric_file = ''  # output YAML for the metrics; its name selects the method branch and its folder receives the figures
matching_file = ''  # matching table: read for paste, written for every other method
ground_truth = 60  # forwarded to rotation_angle(ground_truth=...); presumably the expected rotation in degrees — TODO confirm

In [None]:
# Read both slices in one pass: adata1 from adata1_file, adata2 from adata2_file.
adata1, adata2 = (sc.read_h5ad(h5ad_path) for h5ad_path in (adata1_file, adata2_file))

# Matching

In [None]:
# The method name embedded in metric_file decides where the matching comes from.
method_tag = metric_file.lower()
if 'paste' in method_tag:
    # The matching is read from matching_file as an integer table.
    matching = np.loadtxt(matching_file, dtype=int)
else:
    if 'harmony' in method_tag:
        embd0, embd1 = adata1.obsm['X_harmony'], adata2.obsm['X_harmony']
    elif 'pca' in method_tag:
        embd0, embd1 = adata1.obsm['X_pca'], adata2.obsm['X_pca']
    else:
        # Any other method: embeddings are loaded from CSV and wrapped as tensors.
        embd0, embd1 = (
            torch.from_numpy(np.loadtxt(emb_path, delimiter=','))
            for emb_path in (emb0_file, emb1_file)
        )

    best, index, distance = spatial_match([embd0, embd1])
    # Row 0 enumerates 0..len(index)-1, row 1 holds `best` as returned by spatial_match.
    matching = np.array([range(index.shape[0]), best])

In [None]:
# Choose the annotation columns (biology_meta / topology_meta) and the plotting
# spot size from keywords found in the name of the first slice file.
#
# Each keyword is tested explicitly: in `'visium' and 'DLPFC' in adata1_file`
# the bare string 'visium' is always truthy, so only the second keyword was
# ever checked.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
    spot_size = 5
elif 'merfish' in adata1_file and 'hypothalamic' in adata1_file:
    biology_meta = 'Cell_class'
    topology_meta = 'region'
    spot_size = 15
elif 'stereo' in adata1_file and 'embryo' in adata1_file:
    biology_meta = 'annotation'
    topology_meta = 'region'
    spot_size = 5
elif 'brain' in adata1_file:
    biology_meta = 'layer_guess'
    topology_meta = 'layer_guess'
    spot_size = 5
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(
        f"Cannot infer the dataset type from adata1_file={adata1_file!r}; "
        "expected a name containing visium+DLPFC, merfish+hypothalamic, "
        "stereo+embryo or brain."
    )

In [None]:
# Outputs are written next to the metric file; scanpy's figure directory is
# pointed at the same folder.
sc.settings.figdir = out_dir = Path(os.path.dirname(metric_file))

# Metrics

In [None]:
# Rotation angle.
# The matching is expanded into a dense (adata1.n_obs x adata2.n_obs) matrix
# that receives a 1 at (matching[1][k], matching[0][k]) for every pair k.
pair_count = matching.shape[1]
matching_sparse = sp.coo_matrix(
    (np.ones(pair_count), (matching[1], matching[0])),
    shape=(adata1.n_obs, adata2.n_obs),
)
angle = rotation_angle(
    adata1.obsm['spatial'],
    adata2.obsm['spatial'],
    matching_sparse.toarray(),
    ground_truth=ground_truth,
)

In [None]:
# Three variants of global_score on the same matching (passed transposed):
# both annotation columns, the biology column only, the topology column only.
matched_pairs = matching.T
overall_score = global_score(
    [adata1, adata2], matched_pairs, biology_meta, topology_meta
)
celltype_score = global_score(
    [adata1, adata2], matched_pairs, biology_meta=biology_meta
)
region_score = global_score(
    [adata1, adata2], matched_pairs, topology_meta=topology_meta
)

In [None]:
# Distance metric for the matched pairs; stored under 'euclidean_dis' in the
# metric file (the exact definition lives in scSLAT.metrics.euclidean_dis).
# NOTE(review): this call receives `matching` as-is, whereas global_score is
# given `matching.T` — confirm this orientation is what euclidean_dis expects.
eud = euclidean_dis(adata1, adata2, matching)

## F1

In [None]:
# Prepare the true and transferred labels on adata2 used by the F1 and
# confusion-matrix cells.
#
# Both keywords are tested explicitly: in `'visium' and 'DLPFC' in adata1_file`
# the bare string 'visium' is always truthy, so only 'DLPFC' was ever checked.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    # Cast the labels of both slices to str and prefix them with 'celltype_'.
    adata2.obs[biology_meta] = 'celltype_' + adata2.obs[biology_meta].astype('str')
    adata1.obs[biology_meta] = 'celltype_' + adata1.obs[biology_meta].astype('str')

# Rows of adata1.obs selected by matching[1]; assigned positionally to adata2.obs.
matched_obs = adata1.obs.iloc[matching[1, :], :]
adata2.obs['target_celltype'] = matched_obs[biology_meta].to_list()
adata2.obs['target_region'] = matched_obs[topology_meta].to_list()

# Joint "<celltype>_<region>" labels for the transferred and the true annotation.
adata2.obs['target_celltype_region'] = adata2.obs['target_celltype'].astype('str') + '_' + adata2.obs['target_region'].astype('str')
adata2.obs['celltype_region'] = adata2.obs[biology_meta].astype('str') + '_' + adata2.obs[topology_meta].astype('str')

In [None]:
def _macro_micro_f1(truth, transferred):
    """Return the (macro, micro) averaged F1 of `transferred` against `truth`."""
    return tuple(
        f1_score(truth, transferred, average=averaging)
        for averaging in ('macro', 'micro')
    )


# True adata2 labels vs labels transferred from adata1, at three granularities.
celltype_macro_f1, celltype_micro_f1 = _macro_micro_f1(
    adata2.obs[biology_meta], adata2.obs['target_celltype']
)
region_macro_f1, region_micro_f1 = _macro_micro_f1(
    adata2.obs[topology_meta], adata2.obs['target_region']
)
total_macro_f1, total_micro_f1 = _macro_micro_f1(
    adata2.obs['celltype_region'], adata2.obs['target_celltype_region']
)

## Confusion Matrix

In [None]:
# Distinct true labels of adata2 (first-appearance order) for each granularity;
# they fix the axis order of the confusion matrices.
celltype_label, region_label, celltype_region_label = (
    adata2.obs[label_col].unique().tolist()
    for label_col in (biology_meta, topology_meta, 'celltype_region')
)

In [None]:
# Joint (cell type + region) confusion matrix: true adata2 labels vs labels
# transferred from adata1. The figure side grows with the number of labels.
side = len(celltype_region_label) / 2
fig, ax = plt.subplots(figsize=(side, side))
cm = confusion_matrix(
    adata2.obs['celltype_region'],
    adata2.obs['target_celltype_region'],
    labels=celltype_region_label,
)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=celltype_region_label)
disp.plot(cmap='Reds', xticks_rotation='vertical', ax=ax)
fig.savefig(out_dir / 'joint_confusing_matrix.png', dpi=300, bbox_inches='tight')

In [None]:
# Cell-type confusion matrix: true adata2 labels vs labels transferred from
# adata1. The figure side grows with the number of labels.
side = len(celltype_label) / 2
fig, ax = plt.subplots(figsize=(side, side))
cm = confusion_matrix(
    adata2.obs[biology_meta],
    adata2.obs['target_celltype'],
    labels=celltype_label,
)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=celltype_label)
disp.plot(cmap='Reds', xticks_rotation='vertical', ax=ax)
fig.savefig(out_dir / 'celltype_confusing_matrix.png', dpi=300, bbox_inches='tight')

In [None]:
# Region confusion matrix: true adata2 labels vs labels transferred from
# adata1. The figure side grows with the number of labels.
side = len(region_label) / 2
fig, ax = plt.subplots(figsize=(side, side))
cm = confusion_matrix(
    adata2.obs[topology_meta],
    adata2.obs['target_region'],
    labels=region_label,
)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=region_label)
disp.plot(cmap='Reds', xticks_rotation='vertical', ax=ax)
fig.savefig(out_dir / 'region_confusing_matrix.png', dpi=300, bbox_inches='tight')

## Ground truth (perturb)

In [None]:
# Fraction of pairs whose two indices are equal; only computed when the
# matching file name contains 'perturb', otherwise the sentinel -1 is stored.
match_ratio = (
    (matching[0] == matching[1]).sum() / len(matching[0])
    if 'perturb' in matching_file
    else -1
)

# Save

In [None]:
# Run time: harmony and pca keep theirs in adata1.uns, every other method in
# a run_time.yaml file stored in out_dir.
uns_time_keys = {'harmony': 'harmony_time', 'pca': 'pca_time'}
method_tag = metric_file.lower()
for baseline, uns_key in uns_time_keys.items():
    if baseline in method_tag:
        run_time = adata1.uns[uns_key]
        break
else:
    with open(out_dir / 'run_time.yaml', 'r') as stream:
        run_time_dic = yaml.safe_load(stream)
    run_time = run_time_dic['run_time']

In [None]:
def _to_builtin(value):
    """Return `value` as a plain Python object when it is a NumPy/torch value.

    yaml.dump serialises NumPy scalars with Python-specific object tags that
    yaml.safe_load cannot read back; converting them first keeps the metric
    file plain YAML. Values of any other type are returned unchanged.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray) or torch.is_tensor(value):
        return value.tolist()
    return value


# Collect every metric computed above. The values that were previously stored
# without a cast (scores, distance, run time) are normalised to builtins so
# they are written the same way as the explicitly cast ones.
metric_dic = {
    'global_score': _to_builtin(overall_score),
    'celltype_score': _to_builtin(celltype_score),
    'region_score': _to_builtin(region_score),
    'euclidean_dis': _to_builtin(eud),
    'angle_delta': float(angle),
    'run_time': _to_builtin(run_time),

    'celltype_macro_f1': float(celltype_macro_f1),
    'celltype_micro_f1': float(celltype_micro_f1),
    'region_macro_f1': float(region_macro_f1),
    'region_micro_f1': float(region_micro_f1),
    'total_macro_f1': float(total_macro_f1),
    'total_micro_f1': float(total_micro_f1),

    'match_ratio': float(match_ratio),
}

with open(metric_file, "w") as f:
    yaml.dump(metric_dic, f)

# For paste the matching was read from matching_file, so it is not rewritten.
if 'paste' not in metric_file.lower():
    np.savetxt(matching_file, matching, fmt='%i')

# Plot

In [None]:
def _spot_frame(adata):
    """Build the per-spot table (index, x, y, celltype, region) of one slice."""
    coords = adata.obsm['spatial']
    return pd.DataFrame({'index': range(adata.shape[0]),
                         'x': coords[:, 0],
                         'y': coords[:, 1],
                         'celltype': adata.obs[biology_meta],
                         'region': adata.obs[topology_meta]})


# One table per slice, consumed by the 3D matching plots.
adata1_df = _spot_frame(adata1)
adata2_df = _spot_frame(adata2)

In [None]:
# 3D view of the matching using the 'celltype' column, saved as a PDF in out_dir.
multi_align = match_3D_multi(
    adata1_df,
    adata2_df,
    matching,
    meta='celltype',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)

multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=True,
    save=out_dir / 'match_by_celltype.pdf',
)

In [None]:
# 3D view of the matching using the 'region' column, saved as a PDF in out_dir.
multi_align = match_3D_multi(
    adata1_df,
    adata2_df,
    matching,
    meta='region',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)

multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=True,
    save=out_dir / 'match_by_region.pdf',
)

In [None]:
# 2D matching plot
# Uses the forward matching together with the dataset-specific annotation
# columns and spot size selected from adata1_file.
# NOTE(review): `save` is a bare file name rather than a path under out_dir;
# presumably matching_2d resolves it against sc.settings.figdir — confirm.
matching_2d(matching, adata1, adata2, biology_meta, topology_meta, spot_size, save='matching_2d.pdf')

In [None]:
# Sankey plot
# Shows how the true labels of adata2 map onto the labels transferred from
# adata1 through the matching.
adata2.obs['target_celltype'] = adata1.obs.iloc[matching[1,:],:][biology_meta].to_list()
adata2.obs['target_region'] = adata1.obs.iloc[matching[1,:],:][topology_meta].to_list()


def _flow_table(true_col, target_col):
    """Count adata2 cells per (true label, transferred label) pair.

    Rows are the values of `true_col`, columns the values of `target_col`,
    missing combinations are 0. The axis labels are the ones produced by the
    group-by itself, converted to plain unnamed indexes. Overwriting them
    with `.unique()` (first-appearance order) would attach labels to the
    wrong rows/columns, because group-by results are in sorted/category order.
    """
    table = (
        adata2.obs
        # observed=True keeps only categories that actually occur, i.e. the
        # same set of labels that `.unique()` would report.
        .groupby([true_col, target_col], observed=True)
        .size()
        .unstack(fill_value=0)
    )
    table.index = table.index.to_list()
    table.columns = table.columns.to_list()
    return table


## by cell type
matching_table = _flow_table(biology_meta, 'target_celltype')
print(matching_table)

Sankey(matching_table, prefix=['Slide1', 'Slide2'], save_name=str(out_dir/'celltype_sankey'),
       format='svg', width=1000, height=1000)

## by region
matching_table = _flow_table(topology_meta, 'target_region')
print(matching_table)

Sankey(matching_table, prefix=['Slide1', 'Slide2'], save_name=str(out_dir/'region_sankey'),
       format='svg', width=1000, height=1000)


# Reverse matching

In [None]:
# Reverse direction: for every method except paste, match with the two
# embeddings (and slices) swapped and plot that result.
if 'paste' not in metric_file.lower():
    best_rev, index_rev, _ = spatial_match([embd1, embd0], adatas=[adata2, adata1], reorder=False)
    matching_rev = np.array([range(index_rev.shape[0]), best_rev])
    rev_plot_inputs = (matching_rev, adata2, adata1)
else:
    # paste has no embeddings here, so the forward matching is plotted again.
    rev_plot_inputs = (matching, adata1, adata2)

matching_2d(*rev_plot_inputs, biology_meta, topology_meta, spot_size, save='matching_rev_2d.pdf')