In [None]:
import os
import torch
import pandas as pd
import scanpy as sc
from sklearn import metrics
import multiprocessing as mp
import numpy as np
import squidpy as sq
import scanpy as sc
from SpaceFlow import SpaceFlow
from st_loading_utils import load_DLPFC, load_BC, load_mVC, load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP, load_embryo
from scipy.spatial import *
from sklearn.preprocessing import *

from sklearn.metrics import *
from scipy.spatial.distance import *

In [None]:
def res_search_fixed_clus(adata, fixed_clus_count, increment=0.1):
    '''
        Find a Leiden resolution that yields exactly ``fixed_clus_count`` clusters.

        Candidate resolutions start at 0.2 and go up to (but excluding) 2.5 in
        steps of ``increment``. They are tried from the largest down. The first
        one whose Leiden clustering (``random_state=0``) has
        ``fixed_clus_count`` distinct labels is returned.

        arg1(adata)[AnnData matrix]: must already carry a neighbour graph
            (``sc.tl.leiden`` needs one, e.g. from ``sc.pp.neighbors``).
            ``adata.obs['leiden']`` is overwritten with the clustering at the
            last resolution tried.
        arg2(fixed_clus_count)[int]: target number of clusters.
        arg3(increment)[float]: step between candidate resolutions; must be > 0.

        return:
            resolution[float]. If no candidate reaches the target count, a
            UserWarning is issued and the last (smallest) resolution tried is
            returned.

        raises:
            ValueError: if ``increment`` is not positive.
    '''
    import warnings

    if increment <= 0:
        raise ValueError("increment must be positive, got {}".format(increment))

    res = None
    # np.arange is ascending for a positive step, so [::-1] walks from the
    # largest resolution down.
    for res in np.arange(0.2, 2.5, increment)[::-1]:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        if adata.obs['leiden'].nunique() == fixed_clus_count:
            return res

    warnings.warn(
        "No resolution in [0.2, 2.5) produced {} clusters; returning the last "
        "resolution tried ({}).".format(fixed_clus_count, res)
    )
    return res


def fx_1NN(i, location_in):
    """Return the distance from point ``i`` to its nearest *other* point.

    Parameters
    ----------
    i : int
        Row index of the query point in ``location_in``.
    location_in : array-like, 2-D
        One row per point; converted with ``np.array``.

    Returns
    -------
    float
        Smallest ``distance_matrix`` distance from row ``i`` to any other row.
        This is ``inf`` when ``location_in`` holds a single row.
    """
    coords = np.array(location_in)
    query_row = coords[np.newaxis, i, :]
    distances_from_query = distance_matrix(query_row, coords)[0, :]
    # The point's distance to itself is 0; mask it out so it is never the minimum.
    distances_from_query[i] = np.inf
    nearest = np.min(distances_from_query)
    return nearest


def fx_kNN(i, location_in, k, cluster_in):
    """Flag whether point ``i`` disagrees with most of its ``k`` nearest neighbours.

    Parameters
    ----------
    i : int
        Row index of the query point.
    location_in : array-like, 2-D
        One row per point; converted with ``np.array``.
    k : int
        Number of nearest neighbours to inspect (point ``i`` itself excluded).
    cluster_in : array-like
        Cluster label per point, aligned with the rows of ``location_in``.

    Returns
    -------
    int
        1 if strictly more than ``k / 2`` of the neighbours carry a label
        different from point ``i``'s label, otherwise 0.
    """
    coords = np.array(location_in)
    labels = np.array(cluster_in)

    dists = distance_matrix(coords[i, :][None, :], coords)[0, :]
    dists[i] = np.inf  # never count the point as its own neighbour
    neighbour_idx = np.argsort(dists)[:k]

    n_disagreeing = np.sum(labels[neighbour_idx] != labels[i])
    return 1 if n_disagreeing > (k / 2) else 0

def _compute_CHAOS(clusterlabel, location):
    """CHAOS spatial-continuity score for a clustering.

    The coordinates are standardised per axis with ``StandardScaler``. Then,
    for every cluster with more than two members, the distance from each
    member to its nearest same-cluster member (``fx_1NN``) is summed. The
    grand total is divided by the number of points. Clusters with two or
    fewer members contribute nothing.

    Parameters
    ----------
    clusterlabel : array-like
        Cluster label per point.
    location : array-like, 2-D
        Coordinates per point, aligned with ``clusterlabel``.

    Returns
    -------
    float
    """
    labels = np.array(clusterlabel)
    scaled_coords = StandardScaler().fit_transform(np.array(location))

    cluster_ids = np.unique(labels)
    # Preallocated (and zero-padded for skipped clusters) so the final sum
    # is taken over the same array layout as before.
    per_cluster_total = np.zeros(len(cluster_ids))
    n_filled = 0
    for cid in cluster_ids:
        members = scaled_coords[labels == cid, :]
        n_members = len(members)
        if n_members <= 2:
            continue
        nn_dists = [fx_1NN(j, members) for j in range(n_members)]
        per_cluster_total[n_filled] = np.sum(nn_dists)
        n_filled += 1

    return np.sum(per_cluster_total) / len(labels)

def _compute_PAS(clusterlabel, location):
    """PAS score: fraction of points whose label differs from that of more
    than half of their 10 nearest neighbours (see ``fx_kNN``).

    Parameters
    ----------
    clusterlabel : array-like
        Cluster label per point.
    location : array-like, 2-D
        Coordinates per point, aligned with ``clusterlabel``. Used unscaled.

    Returns
    -------
    float
    """
    labels = np.array(clusterlabel)
    coords = np.array(location)
    n_points = coords.shape[0]
    outlier_flags = [
        fx_kNN(idx, coords, k=10, cluster_in=labels)
        for idx in range(n_points)
    ]
    return np.sum(outlier_flags) / len(labels)


def compute_CHAOS(adata, pred_key, spatial_key='spatial'):
    """CHAOS score of the labels in ``adata.obs[pred_key]`` over the
    coordinates stored in ``adata.obsm[spatial_key]``."""
    labels = adata.obs[pred_key]
    coords = adata.obsm[spatial_key]
    return _compute_CHAOS(labels, coords)

def compute_PAS(adata, pred_key, spatial_key='spatial'):
    """PAS score of the labels in ``adata.obs[pred_key]`` over the
    coordinates stored in ``adata.obsm[spatial_key]``."""
    labels = adata.obs[pred_key]
    coords = adata.obsm[spatial_key]
    return _compute_PAS(labels, coords)

def compute_ASW(adata, pred_key, spatial_key='spatial'):
    """Average silhouette width of ``adata.obs[pred_key]``, measured with
    pairwise ``pdist`` distances between the coordinates in
    ``adata.obsm[spatial_key]`` (a full n x n matrix is materialised)."""
    condensed = pdist(adata.obsm[spatial_key])
    square_dists = squareform(condensed)
    return silhouette_score(X=square_dists, labels=adata.obs[pred_key], metric='precomputed')

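### DLPFC

Runs SpaceFlow 20 times on each of the 12 DLPFC sections and reports clustering agreement with `original_clusters` (ARI, NMI, AMI, homogeneity) plus spatial metrics (CHAOS, PAS, ASW).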

In [None]:
"""DLPFC"""
import time
import tracemalloc

# Settings for this benchmark cell. N_NEIGHBORS is used both for the
# resolution search and for sf.segmentation, so both cluster the same graph.
DLPFC_DATA_DIR = '../benchmarking_data/DLPFC12'
DLPFC_OUT_DIR = './results_0424/DLPFC/'
N_REPEATS = 20
N_NEIGHBORS = 50
SEED = 42  # NOTE(review): every repeat passes this same seed to sf.train.

# [number of target domains, section id]
setting_combinations = [[7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'], [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'], [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676']]

# Create the output directory before SpaceFlow writes embeddings/domains into it.
os.makedirs(DLPFC_OUT_DIR, exist_ok=True)

for n_clusters, dataset in setting_combinations:
    adata = load_DLPFC(root_dir=DLPFC_DATA_DIR, section_id=dataset)
    sc.pp.filter_genes(adata, min_cells=3)
    adata.var_names_make_unique()
    sf = SpaceFlow.SpaceFlow(adata=adata)

    # preprocess
    sf.preprocessing_data(n_top_genes=3000)

    ari_list = []
    nmi_list = []
    ami_list = []
    hm_list = []
    time_list = []
    memory_list = []
    chaos_list = []
    pas_list = []
    asw_list = []
    for run_idx in range(N_REPEATS):
        pred_key = 'pred_{}'.format(run_idx + 1)
        run_prefix = DLPFC_OUT_DIR + dataset + "_" + str(run_idx)
        embedding_path = run_prefix + "embedding.tsv"
        domains_path = run_prefix + "domains.tsv"

        tracemalloc.start()
        start_time = time.time()

        sf.train(spatial_regularization_strength=0.1,
                 embedding_save_filepath=embedding_path,
                 z_dim=50,
                 lr=1e-3,
                 epochs=1000,
                 max_patience=50,
                 min_stop=100,
                 random_seed=SEED,
                 gpu=1,
                 regularization_acceleration=True,
                 edge_subset_sz=1000000)

        # Search the resolution on the graph that sf.segmentation clusters (a
        # kNN graph of the learned embedding), not on adata's expression graph.
        # Otherwise the chosen resolution does not give n_clusters domains.
        # NOTE(review): assumes sf.segmentation builds
        # sc.pp.neighbors(use_rep='X', n_neighbors=...) on sf.embedding and runs
        # sc.tl.leiden with its default random_state; confirm against the
        # installed SpaceFlow version.
        embedding_adata = sc.AnnData(sf.embedding)
        sc.pp.neighbors(embedding_adata, n_neighbors=N_NEIGHBORS, use_rep='X')
        eval_resolution = res_search_fixed_clus(embedding_adata, n_clusters)

        sf.segmentation(domain_label_save_filepath=domains_path,
                        n_neighbors=N_NEIGHBORS,
                        resolution=eval_resolution)

        # Stop timing/tracing here so evaluation cost is not counted as method cost.
        during = time.time() - start_time
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        pred = pd.read_csv(domains_path, header=None)
        adata.obs[pred_key] = np.array(pred.iloc[:, 0].to_list())

        # Only drop spots missing a ground-truth or predicted label, not spots
        # with NaNs in unrelated obs columns.
        labeled = adata.obs[['original_clusters', pred_key]].dropna()
        truth = labeled['original_clusters']
        predicted = labeled[pred_key]
        ari = metrics.adjusted_rand_score(truth, predicted)
        nmi = metrics.normalized_mutual_info_score(truth, predicted)
        ami = metrics.adjusted_mutual_info_score(truth, predicted)
        homogeneity = metrics.homogeneity_score(truth, predicted)

        chaos = compute_CHAOS(adata, pred_key)
        pas = compute_PAS(adata, pred_key)
        asw = compute_ASW(adata, pred_key)

        ari_list.append(ari)
        nmi_list.append(nmi)
        ami_list.append(ami)
        hm_list.append(homogeneity)
        chaos_list.append(chaos)
        pas_list.append(pas)
        asw_list.append(asw)
        time_list.append(during)
        memory_list.append(peak / 1024 / 1024)

        print('memory blocks peak:{:>10.4f} MB'.format(memory_list[-1]))
        print('time: {:.4f} s'.format(during))
        print('ARI:{}'.format(ari))
        print('NMI:{}'.format(nmi))
        print('AMI:{}'.format(ami))
        print('Homogeneity:{}'.format(homogeneity))
        print('chaos:{}'.format(chaos))
        print('pas:{}'.format(pas))
        print('asw:{}'.format(asw))

    # Persist and summarise the per-run metrics for this section.
    run_metrics = pd.DataFrame({
        'time_s': time_list,
        'memory_peak_MB': memory_list,
        'ARI': ari_list,
        'NMI': nmi_list,
        'AMI': ami_list,
        'Homogeneity': hm_list,
        'CHAOS': chaos_list,
        'PAS': pas_list,
        'ASW': asw_list,
    })
    run_metrics.to_csv(DLPFC_OUT_DIR + dataset + "_metrics.csv", index=False)
    print('{}: mean/std over {} runs'.format(dataset, N_REPEATS))
    print(run_metrics.agg(['mean', 'std']).T)

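### Mouse hypothalamus (mHypo)

Runs SpaceFlow 20 times on each of the five mouse hypothalamus sections loaded by `load_mHypothalamus` and reports clustering metrics (ARI, NMI, AMI, homogeneity) and spatial metrics (CHAOS, PAS, ASW).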

In [None]:
"""MHypo"""
import time
import tracemalloc

# Settings for this benchmark cell. N_NEIGHBORS is used both for the
# resolution search and for sf.segmentation, so both cluster the same graph.
MHYPO_DATA_DIR = '../benchmarking_data/mHypothalamus'
MHYPO_OUT_DIR = './results_0424/mHypo/'
N_REPEATS = 20
N_NEIGHBORS = 50
SEED = 42  # NOTE(review): every repeat passes this same seed to sf.train.

# [number of target domains, section id]
setting_combinations = [[8, '-0.04'], [8, '-0.09'], [8, '-0.14'], [8, '-0.19'], [8, '-0.24']]

# Create the output directory before SpaceFlow writes embeddings/domains into it.
os.makedirs(MHYPO_OUT_DIR, exist_ok=True)

for n_clusters, dataset in setting_combinations:
    adata = load_mHypothalamus(root_dir=MHYPO_DATA_DIR, section_id=dataset)
    sc.pp.filter_genes(adata, min_cells=3)
    adata.var_names_make_unique()
    sf = SpaceFlow.SpaceFlow(adata=adata)

    # preprocess
    sf.preprocessing_data(n_top_genes=3000)

    ari_list = []
    nmi_list = []
    ami_list = []
    hm_list = []
    time_list = []
    memory_list = []
    chaos_list = []
    pas_list = []
    asw_list = []
    for run_idx in range(N_REPEATS):
        pred_key = 'pred_{}'.format(run_idx + 1)
        run_prefix = MHYPO_OUT_DIR + dataset + "_" + str(run_idx)
        embedding_path = run_prefix + "embedding.tsv"
        domains_path = run_prefix + "domains.tsv"

        tracemalloc.start()
        start_time = time.time()

        sf.train(spatial_regularization_strength=0.1,
                 embedding_save_filepath=embedding_path,
                 z_dim=50,
                 lr=1e-3,
                 epochs=1000,
                 max_patience=50,
                 min_stop=100,
                 random_seed=SEED,
                 gpu=1,
                 regularization_acceleration=True,
                 edge_subset_sz=1000000)

        # Search the resolution on the graph that sf.segmentation clusters (a
        # kNN graph of the learned embedding), not on adata's expression graph.
        # Otherwise the chosen resolution does not give n_clusters domains.
        # NOTE(review): assumes sf.segmentation builds
        # sc.pp.neighbors(use_rep='X', n_neighbors=...) on sf.embedding and runs
        # sc.tl.leiden with its default random_state; confirm against the
        # installed SpaceFlow version.
        embedding_adata = sc.AnnData(sf.embedding)
        sc.pp.neighbors(embedding_adata, n_neighbors=N_NEIGHBORS, use_rep='X')
        eval_resolution = res_search_fixed_clus(embedding_adata, n_clusters)

        sf.segmentation(domain_label_save_filepath=domains_path,
                        n_neighbors=N_NEIGHBORS,
                        resolution=eval_resolution)

        # Stop timing/tracing here so evaluation cost is not counted as method cost.
        during = time.time() - start_time
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        pred = pd.read_csv(domains_path, header=None)
        adata.obs[pred_key] = np.array(pred.iloc[:, 0].to_list())

        # Only drop cells missing a ground-truth or predicted label, not cells
        # with NaNs in unrelated obs columns.
        labeled = adata.obs[['original_clusters', pred_key]].dropna()
        truth = labeled['original_clusters']
        predicted = labeled[pred_key]
        ari = metrics.adjusted_rand_score(truth, predicted)
        nmi = metrics.normalized_mutual_info_score(truth, predicted)
        ami = metrics.adjusted_mutual_info_score(truth, predicted)
        homogeneity = metrics.homogeneity_score(truth, predicted)

        chaos = compute_CHAOS(adata, pred_key)
        pas = compute_PAS(adata, pred_key)
        asw = compute_ASW(adata, pred_key)

        ari_list.append(ari)
        nmi_list.append(nmi)
        ami_list.append(ami)
        hm_list.append(homogeneity)
        chaos_list.append(chaos)
        pas_list.append(pas)
        asw_list.append(asw)
        time_list.append(during)
        memory_list.append(peak / 1024 / 1024)

        print('memory blocks peak:{:>10.4f} MB'.format(memory_list[-1]))
        print('time: {:.4f} s'.format(during))
        print('ARI:{}'.format(ari))
        print('NMI:{}'.format(nmi))
        print('AMI:{}'.format(ami))
        print('Homogeneity:{}'.format(homogeneity))
        print('chaos:{}'.format(chaos))
        print('pas:{}'.format(pas))
        print('asw:{}'.format(asw))

    # Persist and summarise the per-run metrics for this section.
    run_metrics = pd.DataFrame({
        'time_s': time_list,
        'memory_peak_MB': memory_list,
        'ARI': ari_list,
        'NMI': nmi_list,
        'AMI': ami_list,
        'Homogeneity': hm_list,
        'CHAOS': chaos_list,
        'PAS': pas_list,
        'ASW': asw_list,
    })
    run_metrics.to_csv(MHYPO_OUT_DIR + dataset + "_metrics.csv", index=False)
    print('{}: mean/std over {} runs'.format(dataset, N_REPEATS))
    print(run_metrics.agg(['mean', 'std']).T)