### 0. Import packages and select a GPU if available

In [1]:
import os
import torch
import argparse
import warnings
import numpy as np
import anndata
import scanpy as sc
import matplotlib.pyplot as plt
import pandas as pd
from src.graph_func import graph_construction
from src.utils_func import mk_dir, adata_preprocess, load_visium_sge
from src.SEDR_train import SEDR_Train
from sklearn.metrics import adjusted_rand_score
from st_loading_utils import load_DLPFC, load_BC, load_mVC, load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP

warnings.filterwarnings('ignore')
# Turn cuDNN off. The switch lives in torch.backends.cudnn; assigning
# `torch.cuda.cudnn_enabled` only creates an unused attribute and leaves
# cuDNN enabled.
torch.backends.cudnn.enabled = False
# Fixed seeds for NumPy and PyTorch (CPU and CUDA generators).
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Run device: CPU when CUDA is unavailable. We recommend using a GPU.
# GPU 3 is still preferred, but only when at least four GPUs are visible;
# otherwise fall back to GPU 0 instead of failing on the first tensor copy.
if torch.cuda.is_available():
    _gpu_index = 3 if torch.cuda.device_count() > 3 else 0
    device = torch.device('cuda', _gpu_index)
else:
    device = torch.device('cpu')
# iters = 1 # for script testing

### 1. DLPFC dataset (12 slides)

Change `dir_` in the DLPFC cell below to the path of your DLPFC data directory, e.g. 'path/to/your/DLPFC/data'.

In [None]:
iters = 20 # for boxplotting


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` is unsuitable for argparse because ``bool('False')`` is
    ``True``: any non-empty string would enable the flag.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
# ______________ Graph / model / training settings _________
parser.add_argument('--k', type=int, default=10, help='parameter k in spatial graph')
parser.add_argument('--knn_distanceType', type=str, default='euclidean',
                    help='graph distance type: euclidean/cosine/correlation')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
parser.add_argument('--cell_feat_dim', type=int, default=200, help='Dim of PCA')
parser.add_argument('--feat_hidden1', type=int, default=100, help='Dim of DNN hidden 1-layer.')
parser.add_argument('--feat_hidden2', type=int, default=20, help='Dim of DNN hidden 2-layer.')
parser.add_argument('--gcn_hidden1', type=int, default=32, help='Dim of GCN hidden 1-layer.')
parser.add_argument('--gcn_hidden2', type=int, default=8, help='Dim of GCN hidden 2-layer.')
parser.add_argument('--p_drop', type=float, default=0.2, help='Dropout rate.')
parser.add_argument('--using_dec', type=_str2bool, default=True, help='Using DEC loss.')
parser.add_argument('--using_mask', type=_str2bool, default=False, help='Using mask for multi-dataset.')
parser.add_argument('--feat_w', type=float, default=10, help='Weight of DNN loss.')
parser.add_argument('--gcn_w', type=float, default=0.1, help='Weight of GCN loss.')
parser.add_argument('--dec_kl_w', type=float, default=10, help='Weight of DEC loss.')
parser.add_argument('--gcn_lr', type=float, default=0.01, help='Initial GNN learning rate.')
parser.add_argument('--gcn_decay', type=float, default=0.01, help='Initial decay rate.')
parser.add_argument('--dec_cluster_n', type=int, default=10, help='DEC cluster number.')
parser.add_argument('--dec_interval', type=int, default=20, help='DEC interval nnumber.')
parser.add_argument('--dec_tol', type=float, default=0.00, help='DEC tol.')
# ______________ Eval clustering Setting _________
parser.add_argument('--eval_resolution', type=int, default=1, help='Eval cluster number.')
parser.add_argument('--eval_graph_n', type=int, default=20, help='Eval graph kN tol.')

# parse_known_args() instead of parse_args(): inside Jupyter, sys.argv carries
# the kernel's own `-f <connection file>` argument, which parse_args() rejects
# with SystemExit. Unrecognised arguments are collected and ignored; options
# given on a real command line are still honoured.
params, _unknown_args = parser.parse_known_args()
params.device = device


def res_search_fixed_clus(adata, fixed_clus_count, increment=0.02):
    '''
    Search for a Leiden resolution that yields a requested number of clusters.

    Candidate resolutions ``np.arange(0.2, 2.5, increment)`` are tried from the
    largest to the smallest. Every trial runs ``sc.tl.leiden`` on ``adata`` and
    therefore overwrites ``adata.obs['leiden']``.

    Args:
        adata: AnnData object on which ``sc.tl.leiden`` can be run
            (the callers in this file compute neighbors first).
        fixed_clus_count (int): desired number of Leiden clusters.
        increment (float): step between candidate resolutions.

    Returns:
        float: the first (largest) resolution giving exactly
        ``fixed_clus_count`` clusters. If no candidate matches, the candidate
        whose cluster count is closest to the target is returned (and a notice
        is printed) instead of silently returning the last resolution tried.

    Raises:
        ValueError: if ``increment`` produces no candidate resolutions.
    '''
    # Same candidates as sorting the ascending grid in reverse order.
    candidates = np.arange(0.2, 2.5, increment)[::-1]
    if len(candidates) == 0:
        raise ValueError(
            'increment=%r yields no candidate resolutions in [0.2, 2.5)' % (increment,))

    closest_res = None
    closest_gap = None
    for res in candidates:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        n_found = len(adata.obs['leiden'].unique())
        if n_found == fixed_clus_count:
            return res
        # Remember the best near-miss; strict '<' keeps the highest resolution on ties.
        gap = abs(n_found - fixed_clus_count)
        if closest_gap is None or gap < closest_gap:
            closest_res, closest_gap = res, gap

    print('res_search_fixed_clus: no resolution gave exactly %d clusters; '
          'using resolution %s (closest cluster count)' % (fixed_clus_count, closest_res))
    return closest_res

In [None]:
"""DLPFC"""
# Each entry is [target number of clusters, DLPFC section id].
setting_combinations = [[7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'], [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'], [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676']]
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    dir_ = './benchmarking_data/DLPFC12'
    adata_h5 = load_DLPFC(root_dir=dir_, section_id=dataset)

    # Preprocessing and the spatial graph are built once per section and
    # reused by every repeated training run below.
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
    params.cell_num = adata_h5.shape[0]
    params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')

    aris = []
    # Named `run_idx` (not `iter`) so the builtin iter() is not shadowed.
    for run_idx in range(iters):
        sedr_net = SEDR_Train(adata_X, graph_dict, params)
        if params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding in an AnnData object for clustering.
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.uns['spatial'] = adata_h5.uns['spatial']
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

        # Pick a Leiden resolution for the target cluster count, then cluster.
        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)

        # ARI against the annotation; rows containing NaN are dropped first.
        obs_df = adata_sedr.obs.dropna()
        ARI = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ARI)
        aris.append(ARI)

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line per section: label followed by the ARI of every run.
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('DLPFC' + dataset + ' ')
        fp.write(' '.join(str(ari) for ari in aris))
        fp.write('\n')

### 2. BC/MA datasets (2 slides)

In [None]:
"""BC"""
# the number of clusters
# Each entry is [target number of clusters, BC section id].
setting_combinations = [[20, 'section1']]
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    # Relative data root, consistent with the other dataset cells (the previous
    # value was an absolute path inside one user's home directory).
    dir_ = './benchmarking_data/BC'
    adata_h5 = load_BC(root_dir=dir_, section_id=dataset)

    # Preprocessing and the spatial graph are built once per section and
    # reused by every repeated training run below.
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
    params.cell_num = adata_h5.shape[0]
    params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')

    aris = []
    # Named `run_idx` (not `iter`) so the builtin iter() is not shadowed.
    for run_idx in range(iters):
        sedr_net = SEDR_Train(adata_X, graph_dict, params)
        if params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding in an AnnData object for clustering.
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.uns['spatial'] = adata_h5.uns['spatial']
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

        # Pick a Leiden resolution for the target cluster count, then cluster.
        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)

        # ARI against the annotation; rows containing NaN are dropped first.
        obs_df = adata_sedr.obs.dropna()
        ARI = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ARI)
        aris.append(ARI)

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line: label followed by the ARI of every run.
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('HBRC1 ')
        fp.write(' '.join(str(ari) for ari in aris))
        fp.write('\n')

In [None]:
"""load mMAMP ma section"""
# Each entry is [target number of clusters, mMAMP section id].
setting_combinations = [[52, 'MA']]
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    dir_ = './benchmarking_data/mMAMP'
    adata_h5 = load_mMAMP(root_dir=dir_, section_id=dataset)

    # Cap the PCA size when the dataset has fewer variables than requested.
    # The cap applies to this dataset only: the requested value is restored in
    # the `finally` below so the shared `params` does not carry a reduced
    # dimension into datasets processed later.
    requested_feat_dim = params.cell_feat_dim
    n_vars = len(adata_h5.var.index)
    if requested_feat_dim > n_vars:
        params.cell_feat_dim = n_vars - 1
    try:
        # Preprocessing and the spatial graph are built once per section and
        # reused by every repeated training run below.
        adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
        graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
        params.cell_num = adata_h5.shape[0]
        params.save_path = mk_dir(save_fold)
        print('==== Graph Construction Finished')

        aris = []
        # Named `run_idx` (not `iter`) so the builtin iter() is not shadowed.
        for run_idx in range(iters):
            sedr_net = SEDR_Train(adata_X, graph_dict, params)
            if params.using_dec:
                sedr_net.train_with_dec()
            else:
                sedr_net.train_without_dec()
            sedr_feat, _, _, _ = sedr_net.process()

            # Wrap the learned embedding in an AnnData object for clustering.
            adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
            adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
            sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

            # Pick a Leiden resolution for the target cluster count, then cluster.
            eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
            sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
            print(adata_sedr.obs)

            # ARI against the annotation; rows containing NaN are dropped first.
            obs_df = adata_sedr.obs.dropna()
            ARI = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

            print('Dataset:', dataset)
            print('ARI:', ARI)
            aris.append(ARI)
    finally:
        params.cell_feat_dim = requested_feat_dim

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line: label followed by the ARI of every run.
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mABC ')
        fp.write(' '.join(str(ari) for ari in aris))
        fp.write('\n')

### 3. mVC/mPFC datasets (4 slides)

In [None]:
"""mVC"""
# Each entry is [target number of clusters, mVC file name used as section id].
setting_combinations = [[7, 'STARmap_20180505_BY3_1k.h5ad']]
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    dir_ = './benchmarking_data/STARmap_mouse_visual_cortex'
    adata_h5 = load_mVC(root_dir=dir_, section_id=dataset)

    # Preprocessing and the spatial graph are built once per section and
    # reused by every repeated training run below.
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
    params.cell_num = adata_h5.shape[0]
    params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')

    aris = []
    # Named `run_idx` (not `iter`) so the builtin iter() is not shadowed.
    for run_idx in range(iters):
        sedr_net = SEDR_Train(adata_X, graph_dict, params)
        if params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding in an AnnData object for clustering.
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

        # Pick a Leiden resolution for the target cluster count, then cluster.
        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)

        # ARI against the annotation; rows containing NaN are dropped first.
        obs_df = adata_sedr.obs.dropna()
        ARI = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ARI)
        aris.append(ARI)

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line: label followed by the ARI of every run.
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mVC ')
        fp.write(' '.join(str(ari) for ari in aris))
        fp.write('\n')

In [None]:
"""mPFC"""
# the number of clusters
# Each entry is [target number of clusters, mPFC section id].
setting_combinations = [[4, '20180417_BZ5_control'], [4, '20180419_BZ9_control'], [4, '20180424_BZ14_control']]
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    dir_ = './benchmarking_data/STARmap_mouse_PFC'
    adata_h5 = load_mPFC(root_dir=dir_, section_id=dataset)

    # Cap the PCA size when the dataset has fewer variables than requested.
    # The cap applies to this section only: the requested value is restored in
    # the `finally` below so the shared `params` does not carry a reduced
    # dimension into sections or datasets processed later.
    requested_feat_dim = params.cell_feat_dim
    n_vars = len(adata_h5.var.index)
    if requested_feat_dim > n_vars:
        params.cell_feat_dim = n_vars - 1
    try:
        # Preprocessing and the spatial graph are built once per section and
        # reused by every repeated training run below.
        adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
        graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
        params.cell_num = adata_h5.shape[0]
        params.save_path = mk_dir(save_fold)
        print('==== Graph Construction Finished')

        aris = []
        # Named `run_idx` (not `iter`) so the builtin iter() is not shadowed.
        for run_idx in range(iters):
            sedr_net = SEDR_Train(adata_X, graph_dict, params)
            if params.using_dec:
                sedr_net.train_with_dec()
            else:
                sedr_net.train_without_dec()
            sedr_feat, _, _, _ = sedr_net.process()

            # Wrap the learned embedding in an AnnData object for clustering.
            adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
            adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
            sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

            # Pick a Leiden resolution for the target cluster count, then cluster.
            eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
            sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
            print(adata_sedr.obs)

            # ARI against the annotation; rows containing NaN are dropped first.
            obs_df = adata_sedr.obs.dropna()
            ARI = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

            print('Dataset:', dataset)
            print('ARI:', ARI)
            aris.append(ARI)
    finally:
        params.cell_feat_dim = requested_feat_dim

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line per section: label followed by the ARI of every run.
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mPFC' + dataset + ' ')
        fp.write(' '.join(str(ari) for ari in aris))
        fp.write('\n')

### 4. mHypothalamus dataset (6 slides)

In [None]:
"""mHypo"""
# Each entry is [target number of clusters, mHypothalamus section id].
setting_combinations = [[8, '-0.04'], [8, '-0.09'], [8, '-0.14'], [8, '-0.19'], [8, '-0.24'], [8, '-0.29']]
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    dir_ = './benchmarking_data/mHypothalamus'
    adata_h5 = load_mHypothalamus(root_dir=dir_, section_id=dataset)

    # Cap the PCA size when the dataset has fewer variables than requested.
    # The cap applies to this section only: the requested value is restored in
    # the `finally` below so the shared `params` does not carry a reduced
    # dimension into sections or datasets processed later.
    requested_feat_dim = params.cell_feat_dim
    n_vars = len(adata_h5.var.index)
    if requested_feat_dim > n_vars:
        params.cell_feat_dim = n_vars - 1
    try:
        # Preprocessing and the spatial graph are built once per section and
        # reused by every repeated training run below.
        adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
        graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
        params.cell_num = adata_h5.shape[0]
        params.save_path = mk_dir(save_fold)
        print('==== Graph Construction Finished')

        aris = []
        # Named `run_idx` (not `iter`) so the builtin iter() is not shadowed.
        for run_idx in range(iters):
            sedr_net = SEDR_Train(adata_X, graph_dict, params)
            if params.using_dec:
                sedr_net.train_with_dec()
            else:
                sedr_net.train_without_dec()
            sedr_feat, _, _, _ = sedr_net.process()

            # Wrap the learned embedding in an AnnData object for clustering.
            adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
            adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
            sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

            # Pick a Leiden resolution for the target cluster count, then cluster.
            eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
            sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
            print(adata_sedr.obs)

            # ARI against the annotation; rows containing NaN are dropped first.
            obs_df = adata_sedr.obs.dropna()
            ARI = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

            print('Dataset:', dataset)
            print('ARI:', ARI)
            aris.append(ARI)
    finally:
        params.cell_feat_dim = requested_feat_dim

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line per section: label followed by the ARI of every run.
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mHypothalamus' + dataset + ' ')
        fp.write(' '.join(str(ari) for ari in aris))
        fp.write('\n')

### 5. Her2Tumor dataset (8 slides)

In [None]:
"""Her2"""
# Each entry is [target number of clusters, Her2 tumor section id].
setting_combinations = [[6, 'A1'], [5, 'B1'], [4, 'C1'], [4, 'D1'], [4, 'E1'], [4, 'F1'], [7, 'G2'], [7, 'H1']]
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    dir_ = './benchmarking_data/Her2_tumor'
    adata_h5 = load_her2_tumor(root_dir=dir_, section_id=dataset)

    # Cap the PCA size when the dataset has fewer variables than requested.
    # The cap applies to this section only: the requested value is restored in
    # the `finally` below so the shared `params` does not carry a reduced
    # dimension into sections or datasets processed later.
    requested_feat_dim = params.cell_feat_dim
    n_vars = len(adata_h5.var.index)
    if requested_feat_dim > n_vars:
        params.cell_feat_dim = n_vars - 1
    try:
        # Preprocessing and the spatial graph are built once per section and
        # reused by every repeated training run below.
        adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
        graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
        params.cell_num = adata_h5.shape[0]
        params.save_path = mk_dir(save_fold)
        print('==== Graph Construction Finished')

        aris = []
        # Named `run_idx` (not `iter`) so the builtin iter() is not shadowed.
        for run_idx in range(iters):
            sedr_net = SEDR_Train(adata_X, graph_dict, params)
            if params.using_dec:
                sedr_net.train_with_dec()
            else:
                sedr_net.train_without_dec()
            sedr_feat, _, _, _ = sedr_net.process()

            # Wrap the learned embedding in an AnnData object for clustering.
            adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
            adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
            sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

            # Pick a Leiden resolution for the target cluster count, then cluster.
            eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
            sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
            print(adata_sedr.obs)

            # ARI against the annotation; rows containing NaN are dropped first.
            obs_df = adata_sedr.obs.dropna()
            ARI = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

            print('Dataset:', dataset)
            print('ARI:', ARI)
            aris.append(ARI)
    finally:
        params.cell_feat_dim = requested_feat_dim

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line per section: label followed by the ARI of every run.
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('Her2tumor' + dataset + ' ')
        fp.write(' '.join(str(ari) for ari in aris))
        fp.write('\n')