In [None]:
import math
import time
import pandas as pd
import numpy as np
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import style
import paste as pst
import ot
import seaborn
from anndata import AnnData
import yaml
from pathlib import Path
import os
import random
import torch

from typing import List, Mapping, Optional, Union
import random

In [None]:
# Inline figures at 150 dpi, saved figures at 200 dpi
sc.set_figure_params(dpi=150, dpi_save=200)

In [None]:
# parameter cells
# NOTE(review): presumably overwritten at run time by an injected parameters cell (e.g. papermill) — confirm.
adata1_file = ''    # h5ad of the first (reference) slice, adata1
adata2_file = ''    # h5ad of the second (query) slice, adata2
metrics_file = ''   # output YAML with the scores and the runtime
matching_file = ''  # output text file with the 2 x n matching array

# hyperparameters
alpha = 0.0  # forwarded as `alpha` to pst.pairwise_align

# emb0_file = '' # PASTE does not produce embeddings
# emb1_file = ''

In [None]:
def global_seed(seed: int):
    r"""
    Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Parameters
    ----------
    seed
        int; ``-1`` draws a non-deterministic seed via ``torch.seed()``.
        Any value above ``2**32 - 1`` is right-shifted by 32 bits before use.
    """
    if seed == -1:
        seed = torch.seed()
    if seed >= 2**32:
        seed >>= 32

    # same order as before: random -> numpy -> torch CPU -> torch CUDA
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    print(f"Global seed set to {seed}.")

def get_metric(adatas:List[AnnData],
                matching:np.ndarray,
                biology_meta:Optional[str]='',
                topology_meta:Optional[str]=''
    ) -> float:
    r"""
    Fraction of query cells whose matched reference cell carries the same label.

    The label of a cell is ``"<biology_meta>-<topology_meta>"``. If one of the
    two columns is missing from ``.obs`` it is filled with ``'Unknown'`` (a
    warning is printed), so effectively only the other column is compared.

    Parameters
    ----------
    adatas
        ``[reference, query]``. Both ``.obs`` tables are modified in place: a
        ``'global_meta'`` column is (re)written and a missing meta column is
        filled with ``'Unknown'``.
    matching
        integer array of shape ``(n_pairs, 2)``; row ``i`` pairs the ``i``-th
        query cell with reference cell ``matching[i, 1]`` (positional index;
        ``matching[i, 0]`` is not read).
    biology_meta
        ``.obs`` column with the biological label (e.g. cell type)
    topology_meta
        ``.obs`` column with the spatial label (e.g. region / layer)

    Returns
    -------
    Number of label-consistent pairs divided by the number of query cells,
    as a plain Python ``float``.
    """
    for adata in adatas:
        assert biology_meta in adata.obs.columns or topology_meta in adata.obs.columns
        if biology_meta not in adata.obs:
            adata.obs[biology_meta] = 'Unknown'
            print("Warning!,biology_meta not in adata.obs ")
        elif topology_meta not in adata.obs:
            adata.obs[topology_meta] = 'Unknown'
            print("Warning!,topology_meta not in adata.obs ")
        adata.obs['global_meta'] = adata.obs[biology_meta].astype(str) + '-' + adata.obs[topology_meta].astype(str)

    pairs = np.asarray(matching)
    n_pairs = pairs.shape[0]
    # Vectorised comparison instead of one `.iloc` row lookup per pair, which is
    # very slow on large slices. Row i compares the i-th query cell with
    # reference cell pairs[i, 1]; out-of-range indices still raise IndexError.
    query_meta = adatas[1].obs['global_meta'].to_numpy()[np.arange(n_pairs)]
    ref_meta = adatas[0].obs['global_meta'].to_numpy()[pairs[:, 1]]
    count = int(np.count_nonzero(query_meta == ref_meta))
    # plain Python float so the score serialises cleanly (e.g. with yaml.dump)
    return float(count / adatas[1].shape[0])

def scanpy_workflow(adata:AnnData,
                    n_top_genes:Optional[int]=2500,
                    n_comps:Optional[int]=50
    ) -> AnnData:
    r"""
    Standard Scanpy preprocessing: Seurat-v3 HVG flagging, total-count
    normalisation, log1p, scaling and PCA.

    Parameters
    ----------
    adata
        AnnData; modified in place (``layers['counts']`` and the
        ``highly_variable`` flag are added when absent, ``.X`` is transformed).
    n_top_genes
        number of highly variable genes passed to ``sc.pp.highly_variable_genes``
    n_comps
        number of principal components passed to ``sc.tl.pca``

    Returns
    -------
    A copy of the processed ``adata``.
    """
    # keep the untransformed matrix around before .X is modified
    if "counts" not in adata.layers:
        adata.layers["counts"] = adata.X.copy()
    # HVG selection runs before normalisation; NOTE: the seurat_v3 flavor is
    # documented to expect count data — keep this order.
    if "highly_variable" not in adata.var.columns:
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor="seurat_v3")
    for transform in (sc.pp.normalize_total, sc.pp.log1p, sc.pp.scale):
        transform(adata)
    sc.tl.pca(adata, n_comps=n_comps, svd_solver="auto")
    return adata.copy()


class match_3D_multi():
    r"""
    Plot the mapping result between 2 datasets

    The two datasets are drawn as stacked slices of a 3D scatter plot
    (``dataset_A`` at z=0, ``dataset_B`` at z=1) and matched cells are joined
    by dashed lines.

    Parameters
    ---------
    dataset_A
        pandas dataframe which contain ['index','x','y'], reference dataset
    dataset_B
        pandas dataframe which contain ['index','x','y'], target dataset
    matching
        matching results, array of shape (2, n_pairs): row 0 holds
        ``dataset_B['index']`` values, row 1 the matched ``dataset_A['index']`` values
    meta
        dataframe colname of meta, such as celltype
    expr
        dataframe colname of gene expr
    subsample_size
        number of matches (columns of ``matching``) randomly drawn for plotting
    reliability
        per-match reliability aligned with the columns of ``matching``; it is
        subsampled together with ``matching`` and only its truthiness is used
        when colouring lines
    scale_coordinate
        if min-max scale coordinates via :math:`(data - min) / (max - min)`
    rotate
        which axes to mirror (``1 - value``) per dataset (forces scale_coordinate),
        such as ['x','y'], means dataset0 is mirrored on x and dataset1 on y
    exchange_xy
        exchange x and y on dataset_B

    Note
    ----------
    dataset_A and dataset_B can have different lengths
    """
    def __init__(self,dataset_A:pd.DataFrame,
                 dataset_B: pd.DataFrame,
                 matching: np.ndarray,
                 meta: Optional[str] = None,
                 expr: Optional[str] = None,
                 subsample_size: Optional[int]=300,
                 reliability: Optional[np.ndarray]=None,
                 scale_coordinate: Optional[bool]=True,
                 rotate: Optional[List[str]]=None,
                 exchange_xy: Optional[bool]=False
        ) -> None:
        self.dataset_A = dataset_A.copy()
        self.dataset_B = dataset_B.copy()
        self.meta = meta
        self.matching = matching
        self.conf = reliability
        # mirroring uses `1 - value`, which needs coordinates in [0, 1]
        if rotate is not None:
            scale_coordinate = True

        assert all(item in dataset_A.columns.values for item in ['index','x','y'])
        assert all(item in dataset_B.columns.values for item in ['index','x','y'])

        if meta:
            # Series.append was removed in pandas 2.0; pd.concat is the replacement
            self.celltypes = set(pd.concat([self.dataset_A[meta], self.dataset_B[meta]]))
            set1 = list(set(self.dataset_A[meta]))
            set2 = list(set(self.dataset_B[meta]))
            overlap = [x for x in set2 if x in set1]
            print(f"dataset1: {len(set1)} cell types; dataset2: {len(set2)} cell types; \n\
                    All :{len(self.celltypes)} celltypes; Overlap: {len(overlap)} cell types \n\
                    Not overlap :[{[y for y in (set1+set2) if y not in overlap]}]"
                    )
        self.expr = expr if expr else False

        if scale_coordinate:
            for i, dataset in enumerate([self.dataset_A, self.dataset_B]):
                for axis in ['x','y']:
                    dataset[axis] = (dataset[axis] - np.min(dataset[axis])) / (np.max(dataset[axis])- np.min(dataset[axis]))
                    if rotate is not None and axis in rotate[i]:
                        dataset[axis] = 1 - dataset[axis]
        if exchange_xy:
            self.dataset_B[['x','y']] = self.dataset_B[['y','x']]

        n_pairs = matching.shape[1]
        subsample_size = subsample_size if n_pairs > subsample_size else n_pairs
        print(f'Subsample {subsample_size} cell pairs from {matching.shape[1]}')
        picked = np.random.choice(n_pairs, subsample_size, replace=False)
        self.matching = matching[:, picked]
        # keep reliability aligned with the subsampled pairs (conf[i] <-> matching[:, i])
        if reliability is not None:
            self.conf = np.asarray(reliability)[picked]

        self.datasets = [self.dataset_A, self.dataset_B]

    def draw_3D(self,
                size: Optional[List[int]]=[10,10],
                point_size: Optional[List[int]]=[0.1,0.1],
                line_width: Optional[float]=0.3,
                line_color:Optional[str]='grey',
                line_alpha: Optional[float]=0.7,
                hide_axis: Optional[bool]=False,
                show_error: Optional[bool]=True,
                cmap: Optional[str]='Reds',
                save:Optional[str]=None
        ) -> None:
        r"""
        Draw 3D picture of two datasets

        Parameters:
        ----------
        size
            plt figure size
        point_size
            point size of every dataset
        line_width
            pair line width
        line_color
            pair line color
        line_alpha
            pair line alpha
        hide_axis
            if hide axis
        show_error
            if show error celltype mapping with different color (ignored without ``meta``)
        cmap
            color map when vis expr
        save
            save file path
        """
        show_error = show_error if self.meta else False
        fig = plt.figure(figsize=(size[0],size[1]))
        ax = fig.add_subplot(111, projection='3d')
        if self.meta:
            # color by different cell types: one random hex colour per category
            color = ["#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)]) for i in range(len(self.celltypes))]
            c_map = {}
            for i, celltype in enumerate(self.celltypes):
                c_map[celltype] = color[i]
            if self.expr:
                c_map = cmap
            for i, dataset in enumerate(self.datasets):
                if self.expr:
                    # expression colour range is normalised per slice
                    norm = plt.Normalize(dataset[self.expr].to_numpy().min(), dataset[self.expr].to_numpy().max())
                for cell_type in self.celltypes:
                    subset = dataset[dataset[self.meta] == cell_type]
                    xs = subset['x']
                    ys = subset['y']
                    zs = i
                    if self.expr:
                        ax.scatter(xs,ys,zs,s=point_size[i],c=subset[self.expr],cmap=c_map,norm=norm)
                    else:
                        ax.scatter(xs,ys,zs,s=point_size[i],c=c_map[cell_type])
        else:
            # plot different point layers, one per dataset
            for i, dataset in enumerate(self.datasets):
                ax.scatter(dataset['x'],dataset['y'],i,s=point_size[i])

        # plot line
        self.draw_lines(ax,show_error,line_color,line_width,line_alpha)
        if hide_axis:
            plt.axis('off')
        if save is not None:
            plt.savefig(save)
        plt.show()

    def draw_lines(self,ax,show_error,default_color,line_width=0.3,line_alpha=0.7) -> None:
        r"""
        Draw a dashed line between every paired cell (dataset_A at z=0, dataset_B at z=1).

        With ``show_error`` and ``meta`` set, lines are coloured by label agreement:
        blue = same ``meta`` label, red = different label; when reliability is
        given, green = same label but unreliable, yellow = different label but
        reliable. Otherwise every line uses ``default_color``.
        """
        for i in range(self.matching.shape[1]):
            query_idx, ref_idx = self.matching[:, i]
            query_row = self.dataset_B['index'] == query_idx
            ref_row = self.dataset_A['index'] == ref_idx
            color = default_color
            if show_error and self.meta is not None:
                same_label = np.array_equal(
                    self.dataset_B.loc[query_row, self.meta].astype(str).to_numpy(),
                    self.dataset_A.loc[ref_row, self.meta].astype(str).to_numpy())
                color = '#ade8f4' if same_label else '#ffafcc'  # blue / red
                # `is not None`: truth-testing a numpy array raises ValueError
                if self.conf is not None:
                    if same_label and not self.conf[i]:
                        color = '#588157'  # green: low reliability but right
                    elif not same_label and self.conf[i]:
                        color = '#ffb703'  # yellow: high reliability but error

            point0 = np.append(self.dataset_A.loc[ref_row, ['x','y']].to_numpy(), 0)
            point1 = np.append(self.dataset_B.loc[query_row, ['x','y']].to_numpy(), 1)
            coord = np.vstack((point0, point1))
            ax.plot(coord[:,0], coord[:,1], coord[:,2], color=color, linestyle="dashed", linewidth=line_width, alpha=line_alpha)


def euclidean_dis(adata1:AnnData,
                  adata2:AnnData,
                  matching:np.ndarray,
                  spatial_key:Optional[str]='spatial'
    ) -> float:
    r"""
    Mean euclidean distance between the 2D coordinates of matched cells.

    Each dataset's coordinates are divided by their own maximum when any value
    lies outside ``[-1, 1]``; the scaled array is also stored in
    ``adata.obsm['scale_spatial']``. Coordinates already inside ``[-1, 1]`` are
    used as-is.

    Parameters
    ----------
    adata1
        adata1 with spatial coordinates (reference)
    adata2
        adata2 with spatial coordinates (query)
    matching
        integer array of shape (2, n_pairs): ``matching[0]`` holds positional
        indices into adata2, ``matching[1]`` the matched positions in adata1
    spatial_key
        key of spatial data in adata.obsm

    Returns
    -------
    Mean distance over all pairs as a Python float (``nan`` if there are no pairs).
    """
    def _coords(adata):
        # Scale only when needed; otherwise use the raw coordinates instead of
        # relying on a (possibly missing or stale) 'scale_spatial' entry.
        coords = adata.obsm[spatial_key]
        if abs(coords.max()) > 1 or abs(coords.min()) > 1:
            scaled = coords / coords.max()
            adata.obsm['scale_spatial'] = scaled
            return scaled
        return coords

    matching = np.asarray(matching)
    coord1 = _coords(adata1)[matching[1, :]]
    coord2 = _coords(adata2)[matching[0, :]]
    distance = np.sqrt((coord1[:, 0] - coord2[:, 0])**2 + (coord1[:, 1] - coord2[:, 1])**2)
    if distance.size == 0:
        return float('nan')
    return float(distance.sum() / distance.shape[0])


In [None]:
# Load both slices (adata1 first, then adata2)
adata1, adata2 = (sc.read_h5ad(path) for path in (adata1_file, adata2_file))

In [None]:
# Align adata1 -> adata2 with PASTE, initialised from the spatial-heuristic mapping pi0
start = time.time()
pi0 = pst.match_spots_using_spatial_heuristic(adata1.obsm['spatial'], adata2.obsm['spatial'], use_ot=True)
pi12 = pst.pairwise_align(adata1, adata2,
                          use_gpu=torch.cuda.is_available(),
                          backend=ot.backend.TorchBackend(),
                          alpha=alpha, G_init=pi0, norm=True, verbose=True)
# Measure once so the printed runtime and the runtime saved to metrics_file agree
run_time = str(time.time() - start)
print('Runtime: ' + run_time)

In [None]:
# Alignment matrix as a DataFrame; presumably rows = adata1 spots and
# columns = adata2 spots (that is how `matching` is built below) — confirm
result = pd.DataFrame(data=pi12)
result.shape

In [None]:
# Row 0: column index of `result` (adata2 spot); row 1: row index of the
# largest entry in that column (matched adata1 spot)
matching_index = result.to_numpy().argmax(axis=0)
matching = np.vstack([np.arange(result.shape[1]), matching_index])

In [None]:
# Pick the .obs columns used for scoring from a dataset keyword in the file path.
# Only the dataset keyword is tested; the technology name (visium / merfish /
# stereo) is not required to appear in the path.
if 'DLPFC' in adata1_file:            # visium DLPFC
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
elif 'hypothalamic' in adata1_file:   # merfish hypothalamic
    biology_meta = 'Cell_class'
    topology_meta = 'region'
elif 'embryo' in adata1_file:         # stereo embryo
    biology_meta = 'annotation'
    topology_meta = 'region'
elif 'brain' in adata1_file:
    biology_meta = 'layer_guess'
    topology_meta = 'layer_guess'
else:
    # fail here with a clear message instead of a NameError in the metrics cell
    raise ValueError(f"Cannot infer annotation columns from adata1_file={adata1_file!r}")

In [None]:
# Global score (both labels must agree), then each label on its own
metric, celltype_score, region_score = (
    get_metric([adata1, adata2], matching.T, **label_kwargs)
    for label_kwargs in (
        dict(biology_meta=biology_meta, topology_meta=topology_meta),
        dict(biology_meta=biology_meta),
        dict(topology_meta=topology_meta),
    )
)

In [None]:
# Mean distance between matched spots (coordinates taken from obsm['spatial'])
eud = euclidean_dis(adata1, adata2, matching, spatial_key='spatial')

# Save metrics

In [None]:
# Folder of metrics_file; also used for the figures of the visualization section
out_dir = Path(os.path.split(metrics_file)[0])

In [None]:
# save the scores (runtime is stored as a string) to YAML
metric_dic = dict(
    global_score=metric,
    celltype_score=celltype_score,
    region_score=region_score,
    euclidean_dis=eud,
    run_time=run_time,
)

with open(metrics_file, "w") as f:
    yaml.dump(metric_dic, f)

In [None]:
# 2 x n integer array: row 0 = adata2 spot, row 1 = matched adata1 spot
np.savetxt(fname=matching_file, X=matching, fmt='%i')

# Visualization (optional; the cells below are commented out)

In [None]:
# sc.pl.spatial(adata1, color = [biology_meta, topology_meta], spot_size=5)
# sc.pl.spatial(adata2, color = [biology_meta, topology_meta], spot_size=5)

In [None]:
# pis = [pi12]
# slices = [adata1, adata2]

# new_slices = pst.stack_slices_pairwise(slices, pis)

In [None]:
# slice_colors = ['#e41a1c','#377eb8','#4daf4a','#984ea3']

# plt.figure(figsize=(7,7))
# for i in range(len(new_slices)):
#     pst.plot_slice(new_slices[i],slice_colors[i],s=50)
# plt.legend(handles=[mpatches.Patch(color=slice_colors[0], label='1'),mpatches.Patch(color=slice_colors[1], label='2')])
# plt.gca().invert_yaxis()
# plt.axis('off')
# plt.savefig(out_dir / 'PASTE.pdf', dpi=200, format ='pdf')
# plt.show()

In [None]:
# adata1_df = pd.DataFrame({'index':range(adata1.shape[0]),
#                           'x': adata1.obsm['spatial'][:,0],
#                           'y': adata1.obsm['spatial'][:,1],
#                           'celltype':adata1.obs[biology_meta],
#                           'region':adata1.obs[topology_meta]})
# adata2_df = pd.DataFrame({'index':range(adata2.shape[0]),
#                           'x': adata2.obsm['spatial'][:,0],
#                           'y': adata2.obsm['spatial'][:,1],
#                           'celltype':adata2.obs[biology_meta],
#                           'region':adata2.obs[topology_meta]})

In [None]:
# multi_align = match_3D_multi(adata1_df, adata2_df,matching,meta='celltype',
#                              scale_coordinate=True,subsample_size=300,exchange_xy=False)

# multi_align.draw_3D(size=[7, 8], line_width =1, point_size=[0.8,0.8], hide_axis=True, show_error=True, save=out_dir / 'match_by_celltype.pdf')

In [None]:
# multi_align = match_3D_multi(adata1_df, adata2_df,matching,meta='region',
#                              scale_coordinate=True,subsample_size=300,exchange_xy=False)

# multi_align.draw_3D(size=[7, 8],line_width=1, point_size=[0.8,0.8], hide_axis=True, show_error=True, save=out_dir / 'match_by_region.pdf')