In [1]:
# Single seed shared by every stochastic step in this notebook.
seed_num = 2022

import tensorflow as tf
from numpy.random import seed
from tensorflow.random import set_seed

# Seed NumPy's global RNG and TensorFlow's global seed explicitly, then let
# Keras apply the same seed through its own helper.
seed(seed_num)
set_seed(seed_num)
tf.keras.utils.set_random_seed(seed_num)
# Ask TensorFlow for deterministic op implementations so training is reproducible.
tf.config.experimental.enable_op_determinism()


import os
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from IPython.display import clear_output
from time import time
import anndata as ad
import seaborn as sns
import scanpy as sc
import pylab as pl

from SC2Spa import SI, PP, Vis

pd.set_option('display.max_columns', None)
%matplotlib inline

/mnt/win1
/mnt/win1/SC2Spa_Notebook/MouseEmbryo_SSV2


In [2]:
# Load the reference (AdataEmbryo1.h5ad) and the query (MOCA_12_5.h5ad) and give
# both the same preprocessing:
#   * upper-case gene symbols, presumably so the two gene namespaces line up;
#   * make gene names unique;
#   * normalise each cell to 1e4 total counts;
#   * log1p-transform.
# Each step touches only its own dataset, so running the steps dataset-by-dataset
# gives the same result as running them operation-by-operation.
adata_ref = ad.read_h5ad('../Dataset/AdataEmbryo1.h5ad')
adata_query = ad.read_h5ad('../Dataset/MOCA_12_5.h5ad')

for _adata in (adata_ref, adata_query):
    _adata.var_names = _adata.var_names.str.upper()
    _adata.var_names_make_unique()
    sc.pp.normalize_total(_adata, target_sum=1e4)
    sc.pp.log1p(_adata)

# Attach the RCTD cell-type calls to the reference observations and store the
# bead coordinates as the spatial embedding.
Anno = pd.read_csv('slideSeq_Puck190926_03_RCTD.csv', index_col = 0)
# A duplicated barcode would duplicate rows in the index merge below and break
# the obs/observation alignment of adata_ref.
assert Anno.index.is_unique, 'duplicate barcodes in the RCTD annotation table'

# MCT: celltype_1 when both RCTD calls agree, otherwise celltype_1 joined with a
# normalised celltype_2.
# The values are assigned through .loc with a boolean mask. The previous chained
# form Anno['MCT'][idx] = ... writes into an intermediate Series. That triggers
# SettingWithCopyWarning, and under pandas Copy-on-Write it never reaches Anno.
same_call = Anno['celltype_1'] == Anno['celltype_2']
Anno['MCT'] = 't'
Anno.loc[same_call, 'MCT'] = Anno.loc[same_call, 'celltype_1']
# NOTE(review): .apply binds tighter than '+'. Only celltype_2 has its
# '_'-separated tokens de-duplicated and sorted; celltype_1 is used as-is. If the
# intent was an order-independent pair label, the apply should wrap the whole
# concatenation. Behaviour is kept unchanged here; confirm.
Anno.loc[~same_call, 'MCT'] = (
    Anno.loc[~same_call, 'celltype_1'] + '_'
    + Anno.loc[~same_call, 'celltype_2'].apply(lambda x: '_'.join(sorted(set(x.split('_')))))
)

adata_ref.obs = adata_ref.obs.merge(Anno, left_index = True, right_index = True, how = 'left')

adata_ref.obsm['spatial'] = adata_ref.obs[['xcoord', 'ycoord']].values

In [3]:
# Per-gene Wasserstein distances between reference and query expression,
# computed by SI.WassersteinD. The table is saved as WDs/WDs_T1.csv, which the
# next cell reads back with its own cutoff. The recorded run took ~137 min, so
# the computation is skipped when that file already exists.
RECOMPUTE_WD = False  # set True to force a fresh computation
wd_csv = 'WDs/WDs_T1.csv'

sta = time()
if RECOMPUTE_WD or not os.path.exists(wd_csv):
    JGs, WDs = SI.WassersteinD(adata_ref, adata_query, sparse = True,
                               WD_cutoff = 0.1, root = 'WDs/', save = 'WDs_T1')
else:
    # NOTE(review): rebuilds JGs from the saved table. It assumes
    # SI.WassersteinD keeps genes with Wasserstein_Distance strictly below
    # WD_cutoff, which is the filter the next cell applies. WDs here is the CSV
    # frame and may differ in form from the object the function returns. Confirm
    # if anything other than the next cell (which overwrites both) relies on them.
    WDs = pd.read_csv(wd_csv)
    JGs = sorted(WDs.loc[WDs['Wasserstein_Distance'] < 0.1, 'Gene'].tolist())
end = time()
print((end - sta) / 60.0, 'min')

137.4829457084338 min


In [None]:
# Keep only the joint genes whose reference-vs-query Wasserstein distance (read
# back from the CSV written by SI.WassersteinD) is below WD_cutoff. Then train,
# or reload, the SC2Spa FineMapping model on that gene subset.
WD_cutoff = 0.4

root = 'WDs/'
save = 'WDs_T1'

WDs = pd.read_csv(root + save + '.csv')
JGs = sorted(WDs.loc[WDs['Wasserstein_Distance'] < WD_cutoff, 'Gene'].tolist())
assert JGs, f'no gene in {root + save}.csv has Wasserstein_Distance < {WD_cutoff}'

# Subset into new objects instead of rebinding adata_ref / adata_query.
# Rebinding had two problems:
#   * It silently shrank the inputs of every later cell, although the later
#     SI_T1.h5 FineMapping run logged 18777 joint genes, i.e. it used the
#     unfiltered objects.
#   * Re-running with a larger WD_cutoff asked for genes that had already been
#     dropped.
# .copy() turns the AnnData views into independent objects.
adata_ref_wd = adata_ref[:, JGs].copy()
adata_query_wd = adata_query[:, JGs].copy()

# Re-seed right before fitting so the model does not depend on earlier cells.
seed(seed_num)
set_seed(seed_num)
tf.keras.utils.set_random_seed(seed_num)

# Training runs up to 500 epochs, so reuse the saved model when it exists.
# NOTE(review): assumes FineMapping(save=True, root='Model_SI/', name='SI_T1_WD')
# writes Model_SI/SI_T1_WD.h5. That is the path the earlier reload call used; confirm.
SI_WD_MODEL = 'Model_SI/SI_T1_WD.h5'
RETRAIN_SI_WD = False  # set True to force a fresh fit

sta = time()

if RETRAIN_SI_WD or not os.path.exists(SI_WD_MODEL):
    neighbors, dis = SI.FineMapping(adata_ref_wd, adata_query_wd, sparse = True, model_path = None, root = 'Model_SI/',
                                    name = 'SI_T1_WD', l1_reg = 1e-5, l2_reg = 0, dropout = 0.05, epoch = 500,
                                    batch_size = 4096, nodes = [4096, 1024, 256, 64, 16, 4], lrr_patience = 20,
                                    ES_patience = 50, min_lr = 1e-5, save = True, polar = True,
                                    n_neighbors = 1000, dis_cutoff = 5, seed = seed_num)
else:
    neighbors, dis = SI.FineMapping(adata_ref_wd, adata_query_wd, sparse = True, model_path = SI_WD_MODEL,
                                    polar = True, n_neighbors = 1000, dis_cutoff = 5)

end = time()
print((end - sta) / 60.0, 'min')

In [3]:
# Map query cells onto the reference with the pre-trained SI_T1 model.
# NOTE(review): this reloads Model_SI/SI_T1.h5, not the SI_T1_WD model fitted in
# the previous cell, and uses dis_cutoff=20 instead of 5. Its logged run reported
# 18777 joint genes, so it expects the unfiltered adata_ref / adata_query.
# Confirm this is the model intended for the downstream imputation.
neighbors, dis = SI.FineMapping(
    adata_ref,
    adata_query,
    sparse=True,
    model_path='Model_SI/SI_T1.h5',
    polar=True,
    n_neighbors=1000,
    dis_cutoff=20,
)


n of Referece Genes: 20220
n of Target Genes: 26157
n of Joint Genes: 18777
(41406, 18777)
(259627, 18777)


2022-05-01 20:48:02.727626: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-01 20:48:03.390206: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21942 MB memory:  -> device: 0, name: NVIDIA TITAN RTX, pci bus id: 0000:3b:00.0, compute capability: 7.5


In [8]:
# List the query's per-cell metadata columns (DataFrame.keys() is the column Index).
adata_query.obs.keys()

Index(['all_exon_count', 'all_intron_count', 'all_read_count',
       'intergenic_rate', 'embryo_id', 'embryo_sex', 'nuclei_extraction_date',
       'development_stage', 'Total_mRNAs', 'num_genes_expressed',
       'Size_Factor', 'Main_Cluster', 'Main_cluster_tsne_1',
       'Main_cluster_tsne_2', 'Sub_cluster', 'Sub_cluster_tsne_1',
       'Sub_cluster_tsne_2', 'doublet_score', 'detected_doublet',
       'doublet_cluster', 'sub_cluster_id', 'Main_cell_type',
       'Main_trajectory', 'Main_trajectory_umap_1', 'Main_trajectory_umap_2',
       'Main_trajectory_umap_3', 'Main_trajectory_refined_by_cluster',
       'Main_trajectory_refined_umap_1', 'Main_trajectory_refined_umap_2',
       'Main_trajectory_refined_umap_3', 'Sub_trajectory_name',
       'Sub_trajectory_umap_1', 'Sub_trajectory_umap_2',
       'Sub_trajectory_louvain_component', 'Sub_trajectory_Pseudotime',
       'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt',
       'pct_counts_mt', 'total_counts_ercc', 

In [None]:
# Plot each scRNA-seq main cell type at its transferred coordinates.
# NOTE(review): x_transfer / y_transfer must already be present in adata_query.obs,
# presumably written by an earlier SI step. Confirm.
banner = '*' * 16
for CT in adata_query.obs['Main_cell_type'].unique():
    print(banner, CT, banner, 'Transfer:', sep='\n')
    Vis.DrawCT1(adata_query, CT=CT, c_name='Main_cell_type',
                x_name='x_transfer', y_name='y_transfer',
                root='Transfer2/FM_Valid2/', save='SC2HC1')

In [None]:
# Plot each reference cell type (RCTD celltype_1) at the measured bead coordinates.
# NOTE(review): the cell types come from 'celltype_1', but DrawCT1 colours by the
# 'SSV2' column. If those columns use different label sets, CT may never match.
# Confirm.
for CT in adata_ref.obs['celltype_1'].unique():
    print('*'*16)
    print(CT)
    print('*'*16)
    # This loop draws the measured spatial reference (xcoord/ycoord, c_name='SSV2'),
    # not transferred cells. Label it 'ST:', as the gene-plot cell does for the
    # same data and coordinates.
    print('ST:')
    Vis.DrawCT1(adata_ref, CT = CT, FM = True, c_name = 'SSV2',
                x_name = 'xcoord', y_name = 'ycoord',
                root = 'Transfer2/FM_Valid2/', save = 'HC1')

In [13]:
sta = time()

# Build adata_impute with SI.NRD_impute from the FineMapping neighbours and
# distances plus the reference and query data (recorded runtime ~4.9 min).
t_impute = time()
adata_impute = SI.NRD_impute(
    neighbors, dis, adata_ref, adata_query,
    ct_name='Main_cell_type',
    dis_min=0.1,
    exclude_CTs=None,
)
print((time() - t_impute) / 60, 'min(s)')

4.860832146803538 min(s)


In [None]:
# Compare top-scoring genes across three views:
#   * the scRNA-seq query at its transferred coordinates;
#   * the measured ST reference;
#   * the imputed reference.
cmap = sns.cubehelix_palette(n_colors = 32, start = 2, rot = 1.5, as_cmap = True)

GeneInfo = pd.read_csv('GeneInfo_DS_CTS.csv', index_col = 0)

N_TOP_GENES = 15  # genes taken from each GeneInfo column whose name contains 'norm'

GLs = []
for col in GeneInfo.columns[GeneInfo.columns.str.contains('norm')]:
    GLs.extend(GeneInfo.sort_values(col, ascending = False).index[:N_TOP_GENES].tolist())

# De-duplicate while keeping first-seen order. list(set(GLs)) ordered genes by
# string hash, which Python randomises per process, so the plotting order (and
# the sequence of saved figures) changed between kernel sessions.
GLs = list(dict.fromkeys(GLs))

# Each entry: (printed label, data, panel-specific DrawGenes2 arguments), in display order.
panels = [
    ('scRNAseq:', adata_query, dict(lim = False, c_name = 'simp_name',
                                    x_name = 'x_transfer', y_name = 'y_transfer', save = 'AMB')),
    ('ST:', adata_ref, dict(lim = True, c_name = 'SSV2',
                            x_name = 'xcoord', y_name = 'ycoord', save = 'HC1')),
    ('Imputed ST:', adata_impute, dict(lim = True, c_name = 'SSV2',
                                       x_name = 'xcoord', y_name = 'ycoord', save = 'HC1_Impute')),
]

for gene in GLs:
    print('*'*32, gene, '*'*32, sep='\n')
    for label, data, panel_kwargs in panels:
        print(label)
        # The shared arguments are written per call, so each call gets fresh
        # xlim/ylim lists exactly as the separate literal calls did.
        Vis.DrawGenes2(data, gene = gene, colorbar = False,
                       xlim = [650, 5750], ylim = [650, 5750], cmap = cmap,
                       FM = True, CTL = None, root = 'Transfer2/FM_Valid1/',
                       s = 2, title = False, **panel_kwargs)