In [16]:
import sys
from typing import List, Mapping, Optional, Union

import torch
import numpy as np
import pandas as pd
import scanpy as sc
import faiss
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import euclidean_distances
from anndata import AnnData

import STAGATE_pyG

In [17]:
def _scaled_spatial_coords(adata, scale):
    r"""
    Return a float copy of ``adata.obsm['spatial']``; if ``scale`` is true, min-max scale its first two columns to [0,1]
    """
    # Always work on a float copy: the stored coordinates may be integers, and
    # writing scaled values back into an integer array would truncate them to 0/1.
    coords = np.array(adata.obsm['spatial'], dtype=float)
    if scale:
        for axis in range(2):
            low = np.min(coords[:, axis])
            span = np.max(coords[:, axis]) - low
            if span != 0:
                coords[:, axis] = (coords[:, axis] - low) / span
            else:
                # constant axis: map everything to 0 instead of dividing by zero
                coords[:, axis] = 0.0
    return coords


def spatial_match(embds:List[torch.Tensor],
                  reorder:Optional[bool]=True,
                  top_n:Optional[int]=20,
                  smooth:Optional[bool]=True,
                  smooth_range:Optional[int]=20,
                  scale_coord:Optional[bool]=True,
                  adatas:Optional[List[AnnData]]=None,
                  verbose:Optional[bool]=False
    )-> List[Union[np.ndarray,torch.Tensor]]:
    r"""
    Use embedding to match cells from different datasets based on cosine similarity
    
    Parameters
    ----------
    embds
        list of two embeddings (torch tensors or numpy arrays, one row per cell)
    reorder
        if True and the first embedding has fewer rows than the second, swap them
        (and ``adatas``) so that the larger dataset is the one being searched
    top_n
        return top n of cosine similarity
    smooth
        if smooth the mapping by Euclid distance (only applied when ``adatas`` is given)
    smooth_range
        use how many candidates to do smooth (capped at ``top_n``)
    scale_coord
        if scale the first two spatial coordinates of every adata to [0,1] before smoothing;
        if False the raw coordinates in ``obsm['spatial']`` are compared directly
    adatas
        list of adata object, in the same order as ``embds``; coordinates are read
        from ``obsm['spatial']``
    verbose
        if print log
    
    Note
    ----------
    Automatically use larger dataset as source
    
    Return
    ----------
    Best matching, Top n matching and cosine similarity matrix of top n, returned
    as a tuple of three numpy arrays. Row ``i`` refers to cell ``i`` of the query
    (smaller) dataset and the stored values are row indices into the searched
    (larger) dataset.
    
    Note
    ----------
    Use faiss to accelerate, refer https://github.com/facebookresearch/faiss/issues/95
    """
    if reorder and embds[0].shape[0] < embds[1].shape[0]:
        embd0, embd1 = embds[1], embds[0]
        adatas = adatas[::-1] if adatas is not None else None
    else:
        embd0, embd1 = embds[0], embds[1]

    # faiss normalizes in place, so always hand it private float32 copies and
    # never the caller's arrays / tensor storage.
    embd0_np = embd0.detach().cpu().numpy() if torch.is_tensor(embd0) else embd0
    embd1_np = embd1.detach().cpu().numpy() if torch.is_tensor(embd1) else embd1
    embd0_np = np.array(embd0_np, dtype='float32', order='C')
    embd1_np = np.array(embd1_np, dtype='float32', order='C')

    # inner product of L2-normalized vectors == cosine similarity
    index = faiss.index_factory(embd1_np.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
    faiss.normalize_L2(embd0_np)
    faiss.normalize_L2(embd1_np)
    index.add(embd0_np)
    distance, order = index.search(embd1_np, top_n)

    if smooth and adatas is not None:
        smooth_range = min(smooth_range, top_n)
        if verbose:
            print('Smoothing mapping, make sure object is in same direction')
        # Coordinates are needed whether or not they are rescaled.
        ref_coord = _scaled_spatial_coords(adatas[0], scale_coord)
        query_coord = _scaled_spatial_coords(adatas[1], scale_coord)

        candidates = order[:, :smooth_range]
        n_query = candidates.shape[0]
        # Euclidean distance from every query cell to each of its candidates,
        # computed in one shot: (n_query, smooth_range, n_dim) -> (n_query, smooth_range)
        offset = ref_coord[candidates] - query_coord[:n_query, None, :]
        dis = np.linalg.norm(offset, axis=2)
        # faiss pads missing neighbours with -1; never pick those as best match
        dis[candidates < 0] = np.inf
        best = candidates[np.arange(n_query), np.argmin(dis, axis=1)]
    else:
        best = order[:,0]
    return np.array(best), order, distance


In [13]:
# Extend the module search path with the scSLAT `viz` directory so the
# `multi_dataset` import below can be resolved (presumably the module lives
# there -- verify the relative path against the local checkout).
sys.path.append("../../../scSLAT/viz")
from multi_dataset import match_3D_celltype, match_3D_multi

In [None]:
# Load the RNA and ATAC AnnData objects from their h5ad files.
rna, atac = (
    sc.read_h5ad(h5ad_path)
    for h5ad_path in ('../glue_rna-E11_20um.h5ad', '../glue_atac-E11_20um.h5ad')
)

In [None]:
# Mirror the ATAC coordinates along the y axis (x is kept, y is negated).
# The pristine coordinates are backed up only once, so re-running this cell
# always flips the original data instead of toggling the orientation back.
if 'spatial_bak' not in atac.obsm:
    atac.obsm['spatial_bak'] = atac.obsm['spatial'].copy()
atac.obsm['spatial'] = np.array([atac.obsm['spatial_bak'][:,0],-atac.obsm['spatial_bak'][:,1]]).T

In [None]:
# Build a 20-nearest-neighbour spatial graph for each modality; the result is
# read back from `.uns['Spatial_Net']` below.
for modality in (rna, atac):
    STAGATE_pyG.Cal_Spatial_Net(modality, k_cutoff=20, model='KNN')

# Merge both modalities into one object and stack their spatial graphs.
adata = sc.concat([rna, atac])
adata.uns['Spatial_Net'] = pd.concat(
    [modality.uns['Spatial_Net'] for modality in (rna, atac)]
)

# Failed
STAGATE_pyG can only take highly variable genes (HVGs) as input, so training on the concatenated object fails ([source code](https://github.com/QIFEIDKN/STAGATE_pyG/blob/main/STAGATE_pyG/Train_STAGATE.py)).

In [None]:
# Disabled on purpose: this call failed because STAGATE_pyG only accepts HVGs as input.
# adata = STAGATE_pyG.train_STAGATE(adata, device='cuda')