In [None]:
import numpy as np
import sklearn
import sklearn.metrics.pairwise
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import os

In [None]:
def _match_spot_labels(alignment, labels):
    """Classify every row and column of ``alignment`` by label agreement of its arg-max match.

    Row ``i`` is matched to column ``alignment[i].argmax()``. The label of row ``i``
    is ``labels[i]``; the label of column ``j`` is ``labels[alignment.shape[0] + j]``.

    Returns
    -------
    ad1_match_label : list of int
        One entry per row: 1 if the row's label equals its matched column's label, else 0.
    ad2_match_label : list of int
        One entry per column: 1/0 as above for the *last* row matched to that column,
        or 2 if no row was matched to it.
    n_matched_targets : int
        Number of distinct columns chosen by at least one row.
    """
    n_rows, n_cols = alignment.shape
    matched_idx = alignment.argmax(axis=1)

    ad1_match_label = []
    ad2_match_label = [2] * n_cols  # 2 == "no row picked this column"
    for i, j in enumerate(matched_idx):
        is_match = 1 if labels[i] == labels[j + n_rows] else 0
        ad1_match_label.append(is_match)
        # When several rows pick the same column, the last one wins.
        ad2_match_label[j] = is_match

    return ad1_match_label, ad2_match_label, len(set(matched_idx.tolist()))


def plot_aligned_misaligned(alignment, labels, adata1, adata2, data='DLPFC', sec='151507_151508',
                            tool='STAligner', save_dir="./", spot_size=None):
    """Plot which spots of two sections are aligned / mis-aligned / unaligned and save a PDF.

    Each row of ``alignment`` (a spot of ``adata1``) is matched to its arg-max column
    (a spot of ``adata2``). A match is "aligned" when both spots carry the same label.

    Parameters
    ----------
    alignment : 2-D numpy array
        Rows index spots of ``adata1``, columns index spots of ``adata2``.
    labels : sequence
        Labels for rows first, then columns: row ``i`` -> ``labels[i]``,
        column ``j`` -> ``labels[alignment.shape[0] + j]``.
    adata1, adata2 : AnnData
        Both get an ``obs['matching_spots']`` categorical column written in place.
    data : str
        Dataset name; used to pick a default spot size (only ``'DLPFC'`` has one).
    sec : str
        Section-pair identifier, used in the output file name.
    tool : str
        Alignment tool name, used as the panel title and in the output file name.
    save_dir : str
        Directory where ``"SAM" + tool + sec + "viz.pdf"`` is written.
    spot_size : float, optional
        Marker size passed to ``sc.pl.spatial``. Defaults to 75 for ``data='DLPFC'``.

    Raises
    ------
    ValueError
        If ``alignment`` has no rows, or ``spot_size`` is None and ``data`` has no default.
    """
    if spot_size is None:
        default_spot_sizes = {'DLPFC': 75}
        if data not in default_spot_sizes:
            raise ValueError(f"No default spot size for data={data!r}; pass spot_size explicitly.")
        spot_size = default_spot_sizes[data]
    if alignment.shape[0] == 0:
        raise ValueError("alignment has no rows; nothing to plot.")

    ad1_match_label, ad2_match_label, n_matched_targets = _match_spot_labels(alignment, labels)

    # Categories are created from the integer codes first and then renamed, so the
    # category order (and therefore scanpy's colour assignment) follows 0, 1[, 2].
    adata1.obs['matching_spots'] = ad1_match_label
    adata1.obs['matching_spots'] = (
        adata1.obs['matching_spots'].astype('category').map({1: 'aligned', 0: 'mis-aligned'})
    )
    adata2.obs['matching_spots'] = ad2_match_label
    adata2.obs['matching_spots'] = (
        adata2.obs['matching_spots'].astype('category')
        .map({1: 'aligned', 0: 'mis-aligned', 2: 'unaligned'})
    )

    fig, ax = plt.subplots(2, 1, figsize=(6, 18), gridspec_kw={'height_ratios': [1, 1], 'hspace': 0.2})
    sc.pl.spatial(adata1, title=tool, color="matching_spots", spot_size=spot_size, ax=ax[0], show=False)
    sc.pl.spatial(adata2, title=tool, color="matching_spots", spot_size=spot_size, ax=ax[1], show=False)

    # Hide the categorical legend attached to each panel (if any), without
    # creating a throw-away legend first.
    for axis in ax:
        legend = axis.get_legend()
        if legend is not None:
            legend.remove()

    fig.tight_layout(pad=3.0)

    # Rows per distinct matched column: 1.0 means a one-to-one matching.
    ratio = alignment.shape[0] / n_matched_targets
    fig.text(0.5, 0.03, f"Ratio={ratio:.2f}",
             fontsize=52,
             verticalalignment='bottom',
             horizontalalignment='center')

    fig.savefig(os.path.join(save_dir, "SAM" + tool + sec + "viz.pdf"), bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
def get_ratio(alignment, labels):
    """Return the number of rows per distinct arg-max column of an alignment matrix.

    Every row of ``alignment`` is matched to the column holding its largest value.
    The result is ``alignment.shape[0]`` divided by the number of distinct columns
    chosen. 1.0 means every row picked a different column. Larger values mean
    several rows collapse onto the same column.

    Parameters
    ----------
    alignment : 2-D array-like
        Rows index source spots, columns index target spots.
    labels : sequence
        Not used by the computation; kept for backward compatibility of the signature.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``alignment`` is not 2-D or has no rows or no columns.
    """
    alignment = np.asarray(alignment)
    if alignment.ndim != 2 or 0 in alignment.shape:
        raise ValueError(f"alignment must be a non-empty 2-D array, got shape {alignment.shape}")

    matched_idx = alignment.argmax(axis=1)
    return alignment.shape[0] / len(np.unique(matched_idx))

In [None]:
import anndata

# --- DLPFC loaders ---
def load_DLPFC(root_dir='./DLPFC12', section_id='151507'):
    """Load one DLPFC Visium section together with its ground-truth cluster labels.

    Reads ``<root_dir>/<section_id>`` with ``sc.read_visium`` using the count file
    ``<section_id>_filtered_feature_bc_matrix.h5``. It then attaches column 6 of
    ``<root_dir>/<section_id>/gt/tissue_positions_list_GTs.txt`` as
    ``obs['original_clusters']`` and drops spots without a label. Labels are stored
    as strings of integers (e.g. ``'3'``).

    Section IDs run 151507, ..., 151676 (12 sections in total).

    Parameters
    ----------
    root_dir : str
        Directory containing one sub-directory per section.
    section_id : str
        Section identifier, also the sub-directory name.

    Returns
    -------
    AnnData
        A copy restricted to spots that have a ground-truth label.
    """
    adata = sc.read_visium(path=os.path.join(root_dir, section_id),
                           count_file=section_id + '_filtered_feature_bc_matrix.h5')
    adata.var_names_make_unique()

    gt_path = os.path.join(root_dir, section_id, 'gt', 'tissue_positions_list_GTs.txt')
    # First column is the index; presumably spot barcodes matching adata.obs_names — TODO confirm.
    gt_df = pd.read_csv(gt_path, sep=',', header=None, index_col=0)
    # Index-aligned assignment: spots absent from the GT file (or unlabeled) become NaN.
    adata.obs['original_clusters'] = gt_df.loc[:, 6]

    # Drop only spots lacking a ground-truth label, not NaNs in unrelated obs columns.
    keep_bcs = adata.obs.dropna(subset=['original_clusters']).index
    adata = adata[keep_bcs].copy()
    # Column may be float because of NaNs; normalize to '1', '2', ... strings.
    adata.obs['original_clusters'] = adata.obs['original_clusters'].astype(int).astype(str)
    return adata

def load_paste(path_='../samples/paste', sec='151507_151508', data='DLPFC', dlpfc_root='./DLPFC12'):
    """Load a saved PASTE alignment for a section pair plus the AnnData of both sections.

    Parameters
    ----------
    path_ : str
        Directory containing one sub-directory per section pair.
    sec : str
        Section pair written as ``"<first>_<second>"``; also the sub-directory name
        under ``path_`` that holds ``iter0embedding.npy`` and ``iter0labels.npy``.
    data : str
        Dataset name. Only ``'DLPFC'`` is supported.
    dlpfc_root : str
        Passed to ``load_DLPFC`` as ``root_dir`` for both sections.

    Returns
    -------
    tuple
        ``(labels, alignment, ad1, ad2)``, where ``alignment`` is loaded from
        ``iter0embedding.npy``, ``labels`` from ``iter0labels.npy``, and ``ad1``/``ad2``
        are the first and second sections named in ``sec``.

    Raises
    ------
    ValueError
        If ``data`` is not a supported dataset.
    """
    if data != 'DLPFC':
        raise ValueError(f"Unsupported dataset {data!r}; only 'DLPFC' is implemented.")

    pair_dir = os.path.join(path_, sec)
    alignment = np.load(os.path.join(pair_dir, "iter0embedding.npy"))
    # allow_pickle=True lets numpy unpickle object arrays, which can execute
    # arbitrary code: only load .npy files from a trusted source.
    labels = np.load(os.path.join(pair_dir, "iter0labels.npy"), allow_pickle=True)

    section_ids = sec.split('_')
    ad1 = load_DLPFC(root_dir=dlpfc_root, section_id=section_ids[0])
    ad2 = load_DLPFC(root_dir=dlpfc_root, section_id=section_ids[1])

    return labels, alignment, ad1, ad2

In [None]:
# Section pairs to evaluate, each written as "<first section>_<second section>".
sec_list = ['151507_151508']

for sec in sec_list:
    # PASTE: load its saved alignment/labels together with both DLPFC sections,
    # plot aligned vs. mis-aligned spots, then report the matching ratio.
    labels, alignment3, ad1, ad2 = load_paste(
        path_='../samples/paste',
        sec=sec,
        data='DLPFC',
    )
    plot_aligned_misaligned(
        alignment3, labels, ad1, ad2,
        data='DLPFC',
        sec=sec,
        tool='PASTE',
        save_dir="./",
    )
    print(get_ratio(alignment3, labels))