# STAGATE clustering of mouse mPFC data (STARmap)

In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
from scipy.io import mmread
import os
import sys
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics.cluster import adjusted_rand_score
import seaborn as sns
import gc
import warnings
warnings.filterwarnings("ignore")
import random
import time
import STAGATE

import os
os.environ["R_HOME"] = r"D:\\download\\R\\R-4.1.3" 
os.environ["PATH"] = r"D:\\download\\R\\R-4.1.3\\bin\\x64" + ";" + os.environ["R_HOME"] 

In [None]:
# Evaluate STAGATE embeddings + mclust clustering on three STARmap mPFC control
# sections, for every cluster count k in 10..20 and seeds 0..9.
# Outputs per k: a tab-separated ARI table (sample, seed, ari_value) and, per
# sample, a CSV of the mclust labels for every seed.
ids_list = ['20180417_BZ5_control', '20180419_BZ9_control', '20180424_BZ14_control']
result_dir = '../result/STAGATE'
# open(..., "w") / to_csv below fail if the output folder does not exist yet.
os.makedirs(result_dir, exist_ok=True)

# Switching TF to graph mode is process-wide; do it once, not on every seed.
tf.compat.v1.disable_eager_execution()
plt.rcParams['figure.figsize'] = (3, 3)

# Loading, preprocessing and spatial-network construction depend only on the
# sample, not on k or the seed, so do them once per sample up front.
samples = {}
for ids in ids_list:
    # counts is genes x cells (it is transposed into AnnData's cells x genes).
    counts = pd.read_csv(f'../data/count_{ids}.csv', index_col=0)
    xy = pd.read_csv(f'../data/xy_{ids}.csv', index_col=0)
    gt = pd.read_csv(f'../data/gt_c_{ids}.csv', index_col=0)
    # Reference labels are the first column of gt.
    # NOTE(review): ARI below compares positionally, so this assumes gt rows
    # are in the same cell order as the columns of counts; confirm.
    ground_truth = gt.iloc[:, 0].astype(str).tolist()

    adata = ad.AnnData(counts.T.astype('float64'))
    # NOTE(review): obs columns align on the index (cell ids) while
    # obsm['spatial'] is positional; both assume xy matches counts' cell ids/order.
    adata.obs['array_row'] = xy['x']
    adata.obs['array_col'] = xy['y']
    adata.obs['ground_truth'] = gt.iloc[:, 0].to_numpy()
    adata.var['gene_ids'] = counts.index
    # Only the two coordinate columns belong in the spatial embedding.
    adata.obsm['spatial'] = xy[['x', 'y']].to_numpy()

    # seurat_v3 HVG selection expects raw counts, so it runs before normalising.
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # Constructing the spatial network (rad_cutoff is in xy coordinate units)
    STAGATE.Cal_Spatial_Net(adata, rad_cutoff=400)
    STAGATE.Stats_Spatial_Net(adata)

    samples[ids] = (adata, ground_truth)

for cluster_number in range(10, 21):
    print(cluster_number)
    with open(f"{result_dir}/result_STAGATE_k={cluster_number}.txt", "w") as f:
        f.write("sample\tseed\tari_value\n")
        for ids in ids_list:
            adata, ground_truth = samples[ids]

            df_result = pd.DataFrame()
            for seed in range(10):
                print('seed =', seed)
                # Running STAGATE (writes the embedding into obsm['STAGATE'])
                adata_new = STAGATE.train_STAGATE(adata, alpha=0, random_seed=seed)

                # Run UMAP
                sc.pp.neighbors(adata_new, use_rep='STAGATE', random_state=seed)
                sc.tl.umap(adata_new)

                adata_new = STAGATE.mclust_R(adata_new, used_obsm='STAGATE',
                                             num_cluster=cluster_number, random_seed=seed)

                # The file name includes sample and k; with only the seed, every
                # sample and every k overwrote the same ten figures.
                sc.pl.embedding(adata_new, basis='spatial', color='mclust', s=60, show=True,
                                save=f'_STAGATE_{ids}_k={cluster_number}_{seed}.png')

                df_result[f'{ids}_{seed}'] = adata_new.obs['mclust']

                ari_value = adjusted_rand_score(ground_truth, adata_new.obs['mclust'])
                f.write(f"{ids}\t{seed}\t{ari_value}\n")
                print(f'seed = {seed},ari = {ari_value}')

            df_result.to_csv(f'{result_dir}/{ids}_k={cluster_number}.csv')