### Note: run `data_generation_all.py` first to reproduce the `generated_data` inputs used below

In [None]:
!python data_generation_all.py

### 0. Import packages and select a GPU if available

In [None]:
import os
import sys
import matplotlib
matplotlib.use('Agg')
#matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
#import pylab as pl
#from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from sklearn import metrics
from sklearn.metrics import adjusted_rand_score
from scipy import sparse
#from sklearn.metrics import roc_curve, auc, roc_auc_score
from st_loading_utils import load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP, load_DLPFC, load_BC, load_mVC
import numpy as np
import pickle
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, ChebConv, GATConv, DeepGraphInfomax, global_mean_pool, global_max_pool  # noqa
from torch_geometric.data import Data, DataLoader
from datetime import datetime
import argparse

In [None]:
parser = argparse.ArgumentParser()
# ================ Specify data type first ================
parser.add_argument('--data_type', default='nsc',
                    help='"sc" or "nsc", refers to single cell resolution datasets (e.g. MERFISH) '
                         'and non single cell resolution data (e.g. ST) respectively')
# =========================== args ===============================
parser.add_argument('--data_name', type=str, default='V1_Breast_Cancer_Block_A_Section_1', help="'MERFISH' or 'V1_Breast_Cancer_Block_A_Section_1'")
parser.add_argument('--lambda_I', type=float, default=0.3)  # 0.8 on MERFISH, 0.3 on ST
parser.add_argument('--data_path', type=str, default='generated_data/', help='data path')
parser.add_argument('--model_path', type=str, default='model')
parser.add_argument('--embedding_data_path', type=str, default='Embedding_data')
parser.add_argument('--result_path', type=str, default='results')
parser.add_argument('--DGI', type=int, default=1, help='run Deep Graph Infomax(DGI) model, otherwise direct load embeddings')
parser.add_argument('--load', type=int, default=0, help='Load pretrained DGI model')
parser.add_argument('--num_epoch', type=int, default=5000, help='number of epochs in training DGI')
parser.add_argument('--hidden', type=int, default=256, help='hidden channels in DGI')
parser.add_argument('--PCA', type=int, default=1, help='run PCA or not')
parser.add_argument('--cluster', type=int, default=1, help='run cluster or not')
parser.add_argument('--n_clusters', type=int, default=5, help='number of clusters in Kmeans, when ground truth label is not available.')  # 5 on MERFISH, 20 on Breast
parser.add_argument('--draw_map', type=int, default=1, help='run drawing map')
parser.add_argument('--diff_gene', type=int, default=0, help='Run differential gene expression analysis')
parser.add_argument('--batch_size', type=int, default=512, help='training batch size')
parser.add_argument('--gpu_id', type=str, default="2", help='default gpu id')
# This runs inside a notebook, whose kernel is started with its own
# command-line arguments (e.g. `-f <connection file>`). parse_args() would
# abort with "unrecognized arguments"; parse_known_args() ignores them.
args, _unknown_args = parser.parse_known_args()
iters = 2  # for script testing
# iters = 20  # for boxplotting
# Root directory of the generated per-dataset inputs; the dataset cells below
# append '<data_name>/' to it (overrides the --embedding_data_path default).
args.embedding_data_path = './CCST/generated_data'

### 1. DLPFC dataset (12 slides)

Change `dir_` in the cell below to `path/to/your/DLPFC/data`.

In [None]:
"""DLPFC"""
setting_combinations = [[7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'], [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'], [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676']]
for setting_combi in setting_combinations:
   args.n_clusters = setting_combi[0]  # 7

   args.data_name = setting_combi[1]  # '151673'
   dataset = setting_combi[1]
   args.data_type = 'nsc'
   dir_ = './benchmarking_data/DLPFC12'
   ad = load_DLPFC(root_dir=dir_, section_id=args.data_name)
   aris = []
   args.embedding_data_path = args.embedding_data_path +'/'+ args.data_name +'/'
   args.model_path = args.model_path +'/'+ args.data_name +'/'
   args.result_path = args.result_path +'/'+ args.data_name +'/'
   if not os.path.exists(args.embedding_data_path):
      os.makedirs(args.embedding_data_path) 
   if not os.path.exists(args.model_path):
      os.makedirs(args.model_path) 
   args.result_path = args.result_path+'lambdaI'+str(args.lambda_I) +'/'
   if not os.path.exists(args.result_path):
      os.makedirs(args.result_path) 
   print ('------------------------Model and Training Details--------------------------')
   print(args) 
   
   for iter_ in range(iters):

      
      if args.data_type == 'sc': # should input a single cell resolution dataset, e.g. MERFISH
         from CCST_merfish_utils import CCST_on_MERFISH
         CCST_on_MERFISH(args)
      elif args.data_type == 'nsc': # should input a non-single cell resolution dataset, e.g. V1_Breast_Cancer_Block_A_Section_1
         from CCST_ST_utils import CCST_on_ST
         preds = CCST_on_ST(args)
      else:
         print('Data type not specified')

      # calculate metric ARI
      obs_df = ad.obs.dropna()
      # print(preds)
      # print(obs_df['original_clusters'].to_list())
      ARI = adjusted_rand_score(np.array(preds)[:, 1], obs_df['original_clusters'].to_list())
      
      print('Dataset:', dataset)
      print('ARI:', ARI)
      aris.append(ARI)
   print('Dataset:', dataset)
   print(aris)
   print(np.mean(aris))
   with open('ccst_aris.txt', 'a+') as fp:
      fp.write('DLPFC' + dataset + ' ')
      fp.write(' '.join([str(i) for i in aris]))
      fp.write('\n')

### 2. BC/MA datasets (2 slides)

In [None]:
"""BC"""
setting_combinations = [[20, 'section1']]
for setting_combi in setting_combinations:
   args.n_clusters = setting_combi[0]

   args.data_name = setting_combi[1]
   dataset = setting_combi[1]
   args.data_type = 'nsc'
   dir_ = './benchmarking_data/BC'
   ad = load_BC(root_dir=dir_, section_id=args.data_name)
   aris = []
   args.embedding_data_path = args.embedding_data_path +'/'+ args.data_name +'/'
   args.model_path = args.model_path +'/'+ args.data_name +'/'
   args.result_path = args.result_path +'/'+ args.data_name +'/'
   if not os.path.exists(args.embedding_data_path):
      os.makedirs(args.embedding_data_path) 
   if not os.path.exists(args.model_path):
      os.makedirs(args.model_path) 
   args.result_path = args.result_path+'lambdaI'+str(args.lambda_I) +'/'
   if not os.path.exists(args.result_path):
      os.makedirs(args.result_path) 
   print ('------------------------Model and Training Details--------------------------')
   print(args) 
   
   for iter_ in range(iters):

      
      if args.data_type == 'sc': # should input a single cell resolution dataset, e.g. MERFISH
         from CCST_merfish_utils import CCST_on_MERFISH
         CCST_on_MERFISH(args)
      elif args.data_type == 'nsc': # should input a non-single cell resolution dataset, e.g. V1_Breast_Cancer_Block_A_Section_1
         from CCST_ST_utils import CCST_on_ST
         preds = CCST_on_ST(args)
      else:
         print('Data type not specified')

      # calculate metric ARI
      obs_df = ad.obs.dropna()
      ARI = adjusted_rand_score(np.array(preds)[:, 1], obs_df['original_clusters'].to_list())

      print('Dataset:', dataset)
      print('ARI:', ARI)
      aris.append(ARI)
   print('Dataset:', dataset)
   print(aris)
   print(np.mean(aris))
   with open('ccst_aris.txt', 'a+') as fp:
      fp.write('HBRC1 ')
      fp.write(' '.join([str(i) for i in aris]))
      fp.write('\n')

In [None]:
"""load MA section"""
setting_combinations = [[52, 'MA']]
for setting_combi in setting_combinations:
   args.n_clusters = setting_combi[0]

   args.data_name = setting_combi[1]
   dataset = setting_combi[1]
   args.data_type = 'nsc'
   dir_ = './benchmarking_data/mMAMP'
   ad = load_mMAMP(root_dir=dir_, section_id=args.data_name)
   aris = []
   args.embedding_data_path = args.embedding_data_path +'/'+ args.data_name +'/'
   args.model_path = args.model_path +'/'+ args.data_name +'/'
   args.result_path = args.result_path +'/'+ args.data_name +'/'
   if not os.path.exists(args.embedding_data_path):
      os.makedirs(args.embedding_data_path) 
   if not os.path.exists(args.model_path):
      os.makedirs(args.model_path) 
   args.result_path = args.result_path+'lambdaI'+str(args.lambda_I) +'/'
   if not os.path.exists(args.result_path):
      os.makedirs(args.result_path) 
   print ('------------------------Model and Training Details--------------------------')
   print(args) 
   
   for iter_ in range(iters):

      
      if args.data_type == 'sc': # should input a single cell resolution dataset, e.g. MERFISH
         from CCST_merfish_utils import CCST_on_MERFISH
         CCST_on_MERFISH(args)
      elif args.data_type == 'nsc': # should input a non-single cell resolution dataset, e.g. V1_Breast_Cancer_Block_A_Section_1
         from CCST_ST_utils import CCST_on_ST
         preds = CCST_on_ST(args)
      else:
         print('Data type not specified')

      # calculate metric ARI
      obs_df = ad.obs.dropna()
      ARI = adjusted_rand_score(np.array(preds)[:, 1], obs_df['original_clusters'].to_list())

      print('Dataset:', dataset)
      print('ARI:', ARI)
      aris.append(ARI)
   print('Dataset:', dataset)
   print(aris)
   print(np.mean(aris))
   with open('ccst_aris.txt', 'a+') as fp:
      fp.write('mABC ')
      fp.write(' '.join([str(i) for i in aris]))
      fp.write('\n')

### 3. mVC/mPFC datasets (4 slides)

In [None]:
"""mVC"""
setting_combinations = [[7, 'STARmap_20180505_BY3_1k.h5ad']]
for setting_combi in setting_combinations:
   args.n_clusters = setting_combi[0]

   args.data_name = setting_combi[1]
   dataset = setting_combi[1]
   args.data_type = 'nsc'
   dir_ = './benchmarking_data/STARmap_mouse_visual_cortex'
   ad = load_mVC(root_dir=dir_, section_id=args.data_name)
   aris = []
   args.embedding_data_path = args.embedding_data_path +'/'+ args.data_name +'/'
   args.model_path = args.model_path +'/'+ args.data_name +'/'
   args.result_path = args.result_path +'/'+ args.data_name +'/'
   if not os.path.exists(args.embedding_data_path):
      os.makedirs(args.embedding_data_path) 
   if not os.path.exists(args.model_path):
      os.makedirs(args.model_path) 
   args.result_path = args.result_path+'lambdaI'+str(args.lambda_I) +'/'
   if not os.path.exists(args.result_path):
      os.makedirs(args.result_path) 
   print ('------------------------Model and Training Details--------------------------')
   print(args) 
   
   for iter_ in range(iters):

      
      if args.data_type == 'sc': # should input a single cell resolution dataset, e.g. MERFISH
         from CCST_merfish_utils import CCST_on_MERFISH
         CCST_on_MERFISH(args)
      elif args.data_type == 'nsc': # should input a non-single cell resolution dataset, e.g. V1_Breast_Cancer_Block_A_Section_1
         from CCST_ST_utils import CCST_on_ST
         preds = CCST_on_ST(args)
      else:
         print('Data type not specified')

      # calculate metric ARI
      obs_df = ad.obs.dropna()
      ARI = adjusted_rand_score(np.array(preds)[:, 1], obs_df['original_clusters'].to_list())

      print('Dataset:', dataset)
      print('ARI:', ARI)
      aris.append(ARI)
   print('Dataset:', dataset)
   print(aris)
   print(np.mean(aris))
   with open('ccst_aris.txt', 'a+') as fp:
      fp.write('mVC ')
      fp.write(' '.join([str(i) for i in aris]))
      fp.write('\n')

In [None]:
"""mPFC"""
setting_combinations = [[4, '20180417_BZ5_control'], [4, '20180419_BZ9_control'], [4, '20180424_BZ14_control']]
for setting_combi in setting_combinations:
   args.n_clusters = setting_combi[0]

   args.data_name = setting_combi[1]
   dataset = setting_combi[1]
   args.data_type = 'nsc'
   dir_ = './benchmarking_data/STARmap_mouse_PFC'
   ad = load_mPFC(root_dir=dir_, section_id=args.data_name)
   aris = []
   args.embedding_data_path = args.embedding_data_path +'/'+ args.data_name +'/'
   args.model_path = args.model_path +'/'+ args.data_name +'/'
   args.result_path = args.result_path +'/'+ args.data_name +'/'
   if not os.path.exists(args.embedding_data_path):
      os.makedirs(args.embedding_data_path) 
   if not os.path.exists(args.model_path):
      os.makedirs(args.model_path) 
   args.result_path = args.result_path+'lambdaI'+str(args.lambda_I) +'/'
   if not os.path.exists(args.result_path):
      os.makedirs(args.result_path) 
   print ('------------------------Model and Training Details--------------------------')
   print(args) 
   
   for iter_ in range(iters):

      
      if args.data_type == 'sc': # should input a single cell resolution dataset, e.g. MERFISH
         from CCST_merfish_utils import CCST_on_MERFISH
         CCST_on_MERFISH(args)
      elif args.data_type == 'nsc': # should input a non-single cell resolution dataset, e.g. V1_Breast_Cancer_Block_A_Section_1
         from CCST_ST_utils import CCST_on_ST
         preds = CCST_on_ST(args)
      else:
         print('Data type not specified')

      # calculate metric ARI
      obs_df = ad.obs.dropna()
      ARI = adjusted_rand_score(np.array(preds)[:, 1], obs_df['original_clusters'].to_list())

      print('Dataset:', dataset)
      print('ARI:', ARI)
      aris.append(ARI)
   print('Dataset:', dataset)
   print(aris)
   print(np.mean(aris))
   with open('ccst_aris.txt', 'a+') as fp:
      fp.write('mPFC' + dataset + ' ')
      fp.write(' '.join([str(i) for i in aris]))
      fp.write('\n')

### 4. mHypothalamus dataset (6 slides; the cell below runs 4 of them)

In [None]:
"""mHypo"""
setting_combinations = [[8, '-0.14'], [8, '-0.19'], [8, '-0.24'], [8, '-0.29']]
for setting_combi in setting_combinations:
   args.n_clusters = setting_combi[0]

   args.data_name = setting_combi[1]
   dataset = setting_combi[1]
   args.data_type = 'nsc'
   dir_ = './benchmarking_data/mHypothalamus'
   ad = load_mHypothalamus(root_dir=dir_, section_id=args.data_name)
   aris = []
   args.embedding_data_path = args.embedding_data_path +'/'+ args.data_name +'/'
   args.model_path = args.model_path +'/'+ args.data_name +'/'
   args.result_path = args.result_path +'/'+ args.data_name +'/'
   if not os.path.exists(args.embedding_data_path):
      os.makedirs(args.embedding_data_path) 
   if not os.path.exists(args.model_path):
      os.makedirs(args.model_path) 
   args.result_path = args.result_path+'lambdaI'+str(args.lambda_I) +'/'
   if not os.path.exists(args.result_path):
      os.makedirs(args.result_path) 
   print ('------------------------Model and Training Details--------------------------')
   print(args) 
   
   for iter_ in range(iters):

      
      if args.data_type == 'sc': # should input a single cell resolution dataset, e.g. MERFISH
         from CCST_merfish_utils import CCST_on_MERFISH
         CCST_on_MERFISH(args)
      elif args.data_type == 'nsc': # should input a non-single cell resolution dataset, e.g. V1_Breast_Cancer_Block_A_Section_1
         from CCST_ST_utils import CCST_on_ST
         preds = CCST_on_ST(args)
      else:
         print('Data type not specified')

      # calculate metric ARI
      # obs_df = ad.obs
      # print(obs_df)
      # print(np.array(preds).shape)
      ARI = adjusted_rand_score(np.array(preds)[:, 1], ad.obs['original_clusters'].to_list())
      # exit(-1)
      print('Dataset:', dataset)
      print('ARI:', ARI)
      aris.append(ARI)
   print('Dataset:', dataset)
   print(aris)
   print(np.mean(aris))
   with open('ccst_aris.txt', 'a+') as fp:
      fp.write('mHypothalamus' + dataset + ' ')
      fp.write(' '.join([str(i) for i in aris]))
      fp.write('\n')

### 5. Her2Tumor dataset (8 slides)

In [None]:
"""Her2"""
setting_combinations = [[6, 'A1'], [5, 'B1'], [4, 'C1'], [4, 'D1'], [4, 'E1'], [4, 'F1'], [7, 'G2'], [7, 'H1']]
for setting_combi in setting_combinations:
   args.n_clusters = setting_combi[0]

   args.data_name = setting_combi[1]
   dataset = setting_combi[1]
   args.data_type = 'nsc'
   dir_ = './benchmarking_data/Her2_tumor'
   ad = load_her2_tumor(root_dir=dir_, section_id=args.data_name)
   aris = []
   args.embedding_data_path = args.embedding_data_path +'/'+ args.data_name +'/'
   args.model_path = args.model_path +'/'+ args.data_name +'/'
   args.result_path = args.result_path +'/'+ args.data_name +'/'
   if not os.path.exists(args.embedding_data_path):
      os.makedirs(args.embedding_data_path) 
   if not os.path.exists(args.model_path):
      os.makedirs(args.model_path) 
   args.result_path = args.result_path+'lambdaI'+str(args.lambda_I) +'/'
   if not os.path.exists(args.result_path):
      os.makedirs(args.result_path) 
   print ('------------------------Model and Training Details--------------------------')
   print(args) 
   
   for iter_ in range(iters):

      
      if args.data_type == 'sc': # should input a single cell resolution dataset, e.g. MERFISH
         from CCST_merfish_utils import CCST_on_MERFISH
         CCST_on_MERFISH(args)
      elif args.data_type == 'nsc': # should input a non-single cell resolution dataset, e.g. V1_Breast_Cancer_Block_A_Section_1
         from CCST_ST_utils import CCST_on_ST
         preds = CCST_on_ST(args)
      else:
         print('Data type not specified')

      # calculate metric ARI
      obs_df = ad.obs.dropna()
      ARI = adjusted_rand_score(np.array(preds)[:, 1], obs_df['original_clusters'].to_list())

      print('Dataset:', dataset)
      print('ARI:', ARI)
      aris.append(ARI)
   print('Dataset:', dataset)
   print(aris)
   print(np.mean(aris))
   with open('ccst_aris.txt', 'a+') as fp:
      fp.write('Her2tumor' + dataset + ' ')
      fp.write(' '.join([str(i) for i in aris]))
      fp.write('\n')