## Tutorial for ATAC to RNA Cross-Modality Generation: Step 2

Hi, this is the tutorial for ATAC-to-RNA (denoted as A2R) cross-modality generation. We use the MISAR-seq dataset as the example for this task. Detailed results for this task can be found in Figure 3 of our paper.

Specifically, the model takes three datasets as input, which come from two sources:
* paired spatial multi-omics data from one slice (RNA and ATAC, denoted as S1R and S1A in our paper and Figure 1)
* single-modality spatial data from another slice (ATAC only, denoted as S2A in our paper and Figure 1)

The goal is to generate the missing modality for S2, namely S2R.

We split the whole process into two files. This file is step 2, which contains the core generation framework. If you have not run step 1 yet, please go back to that file first.


## Read in data

In [None]:
%load_ext autoreload
%autoreload 2

In [1]:
import scanpy as sc

In [2]:
import torch
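# pick a GPU if one is available; change the index 2 to match your machine
used_device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

# set the working device (torch.cuda.set_device only accepts CUDA devices)
if used_device.type == 'cuda':
    torch.cuda.set_device(used_device)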

In [None]:
import os
os.chdir('/home/yourworkingpath/SpatialTranslator')

In [4]:
os.chdir('/workdir/sm2888/ST')  # path used for the run shown here; replace with your own

In [5]:
import SpaTranslator as spt


    _____               ______                           __        __              
   / ___/ ____   ____ _/_  __/_____ ____ _ ____   _____ / /____ _ / /_ ____   _____
   \__ \ / __ \ / __ `/ / /  / ___// __ `// __ \ / ___// // __ `// __// __ \ / ___/
  ___/ // /_/ // /_/ / / /  / /   / /_/ // / / /(__  )/ // /_/ // /_ / /_/ // /    
 /____// .___/ \__,_/ /_/  /_/    \__,_//_/ /_//____//_/ \__,_/ \__/ \____//_/     
      /_/                                                                           

SpaTranslator v1.0.8     



We read the data prepared in step 1: the `After_train` ATAC files for S1 and S2, and the `To_train` RNA file for S1.

In [6]:
S1_ATAC_data = sc.read_h5ad('After_train_E15_5-S1_atac.h5ad')

In [7]:
S1_RNA_data = sc.read_h5ad('To_train_E15_5-S1_expr.h5ad')

In [8]:
S2_ATAC_data = sc.read_h5ad('After_train_E15_5-S2_atac.h5ad')

In [9]:
S2_ATAC_data

AnnData object with n_obs × n_vars = 1939 × 244394
    obs: 'tsse', 'n_fragment', 'frac_dup', 'frac_mito', 'sample', 'Sample', 'TSSEnrichment', 'ReadsInTSS', 'ReadsInPromoter', 'ReadsInBlacklist', 'PromoterRatio', 'PassQC', 'NucleosomeRatio', 'nMultiFrags', 'nMonoFrags', 'nFrags', 'nDiFrags', 'Gex_RiboRatio', 'Gex_nUMI', 'Gex_nGenes', 'Gex_MitoRatio', 'BlacklistRatio', 'array_col', 'array_row', 'Combined_Clusters', 'RNA_Clusters', 'ATAC_Clusters', 'cell_type'
    uns: 'peaks', 'reference_sequences'
    obsm: 'AlignedEmbedding', 'insertion', 'spatial'

In [10]:
S2_ATAC_data.obs['cell_type'] = S2_ATAC_data.obs['cell_type'].astype('category')

## Preprocessing and Neighbor Graph Construction

In [11]:
n_neighbors = 6
for adata in [S1_RNA_data, S1_ATAC_data, S2_ATAC_data]:
    spt.build_spatial_graph(adata, knn_neighbors=n_neighbors, method='KNN')

Constructing the spatial graph
Generated graph with 11694 edges across 1949 cells.
Average neighbors per cell: 6.0000
Constructing the spatial graph
Generated graph with 11694 edges across 1949 cells.
Average neighbors per cell: 6.0000
Constructing the spatial graph
Generated graph with 11634 edges across 1939 cells.
Average neighbors per cell: 6.0000
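
The edge counts in the log above follow directly from the KNN rule: each spot is linked to its 6 nearest spots, so 1949 spots give 1949 × 6 = 11694 edges and 1939 spots give 11634. The sketch below illustrates that rule on made-up coordinates. It is not the code behind `spt.build_spatial_graph`; the helper `knn_edges` and the random coordinates are assumptions for illustration only.

```python
import math
import random

def knn_edges(coords, k):
    """Return directed (i, j) edges linking each spot to its k nearest spots."""
    edges = []
    for i, p in enumerate(coords):
        # Distance from spot i to every other spot (the spot itself is excluded).
        dists = [(math.dist(p, q), j) for j, q in enumerate(coords) if j != i]
        dists.sort()
        edges.extend((i, j) for _, j in dists[:k])
    return edges

random.seed(0)
# Hypothetical 2-D spot coordinates standing in for adata.obsm['spatial'].
coords = [(random.random(), random.random()) for _ in range(200)]
edges = knn_edges(coords, k=6)

# Every spot contributes exactly k edges, so the average is k.
print(len(edges), len(edges) / len(coords))
```

This brute-force version compares every pair of spots, which is fine for a couple of thousand spots; a tree-based neighbor search would be the usual choice for larger slices.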


In [None]:
# normalize and log-transform the S1 RNA data, then flag its highly variable genes
sc.pp.normalize_total(S1_RNA_data)
sc.pp.log1p(S1_RNA_data)
sc.pp.highly_variable_genes(S1_RNA_data, n_top_genes=5000)

# real S2 RNA data, processed the same way
RNA_data_processed = sc.read_h5ad('To_train_E15_5-S2_expr.h5ad')
sc.pp.normalize_total(RNA_data_processed)
sc.pp.log1p(RNA_data_processed)
sc.pp.highly_variable_genes(RNA_data_processed, n_top_genes=5000)

In [15]:
# get highly variable genes
hvg_S1 = S1_RNA_data.var[S1_RNA_data.var['highly_variable']].index
hvg_processed = RNA_data_processed.var[RNA_data_processed.var['highly_variable']].index

# get common highly variable genes
common_hvgs = hvg_S1.intersection(hvg_processed)

# filter data and only keep common highly variable genes
S1_RNA_data = S1_RNA_data[:, S1_RNA_data.var_names.isin(common_hvgs)]
RNA_data_processed = RNA_data_processed[:, RNA_data_processed.var_names.isin(common_hvgs)]
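
The gene selection above keeps only genes that are highly variable in both slices, so the two RNA matrices end up with the same genes. A minimal sketch of that intersection on toy lists is shown here; the gene lists are made up, and plain Python lists stand in for the pandas indexes used in the cell above.

```python
# Hypothetical HVG lists for two slices (gene names are examples only).
hvg_slice1 = ["Snhg6", "Tcf24", "Eya1", "Msc", "Sox2"]
hvg_slice2 = ["Eya1", "Pax6", "Snhg6", "Msc"]

# Keep genes flagged in both slices, preserving the order of the first slice.
shared = set(hvg_slice2)
common = [g for g in hvg_slice1 if g in shared]

print(common)  # ['Snhg6', 'Eya1', 'Msc']
```

The intersection is usually much smaller than either list: in the run shown in this tutorial, 5000 highly variable genes per slice leave the 1217 genes that appear later as the RNA input dimension.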


In [16]:
S1_RNA_data.var

|  | gene_ids | feature_types | genome | highly_variable | means | dispersions | dispersions_norm |
|---|---|---|---|---|---|---|---|
| Snhg6 | ENSMUSG00000098234 | Gene Expression | mm10 | True | 1.806389e-02 | -0.300393 | 0.697886 |
| Tcf24 | ENSMUSG00000099032 | Gene Expression | mm10 | True | 1.918491e-03 | -0.317254 | 0.657134 |
| Eya1 | ENSMUSG00000025932 | Gene Expression | mm10 | True | 1.314915e-01 | -0.083303 | 2.602218 |
| Msc | ENSMUSG00000025930 | Gene Expression | mm10 | True | 8.504798e-03 | 0.473475 | 2.568211 |
| Crispld1 | ENSMUSG00000025776 | Gene Expression | mm10 | True | 5.991552e-02 | -0.079247 | 2.022538 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| Reps2 | ENSMUSG00000040855 | Gene Expression | mm10 | True | 2.175044e-01 | -0.495685 | 0.997916 |
| Mid1 | ENSMUSG00000035299 | Gene Expression | mm10 | True | 2.798040e-01 | -0.596138 | 0.904185 |
| mt-Co1 | ENSMUSG00000064351 | Gene Expression | mm10 | True | 6.870210e-01 | -1.210864 | 1.165279 |
| mt-Nd3 | ENSMUSG00000064360 | Gene Expression | mm10 | True | 8.321977e-02 | -0.243758 | 1.147481 |


In [17]:
RNA_data_processed.var

|  | gene_ids | feature_types | genome | highly_variable | means | dispersions | dispersions_norm |
|---|---|---|---|---|---|---|---|
| Snhg6 | ENSMUSG00000098234 | Gene Expression | mm10 | True | 6.828864e-02 | 0.630215 | 1.180498 |
| Tcf24 | ENSMUSG00000099032 | Gene Expression | mm10 | True | 4.771582e-03 | 0.492146 | 0.838558 |
| Eya1 | ENSMUSG00000025932 | Gene Expression | mm10 | True | 1.845367e-01 | 0.393316 | 0.593794 |
| Msc | ENSMUSG00000025930 | Gene Expression | mm10 | True | 1.478886e-02 | 0.496599 | 0.849585 |
| Crispld1 | ENSMUSG00000025776 | Gene Expression | mm10 | True | 8.068594e-02 | 0.569555 | 1.030269 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| Reps2 | ENSMUSG00000040855 | Gene Expression | mm10 | True | 3.789510e-01 | 0.446772 | 0.719084 |
| Mid1 | ENSMUSG00000035299 | Gene Expression | mm10 | True | 4.929389e-01 | 0.530478 | 1.128766 |
| mt-Co1 | ENSMUSG00000064351 | Gene Expression | mm10 | True | 1.814009e+00 | 0.788331 | 0.562708 |
| mt-Nd3 | ENSMUSG00000064360 | Gene Expression | mm10 | True | 2.580433e-01 | 0.869502 | 3.226957 |


## Data Splitting and Model Loading

In [18]:
train_id, validation_id = spt.split_dataset_by_slices(S1_RNA_data, S1_ATAC_data)

Training set size: 1560
Validation set size: 389
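
The sizes reported above (1560 and 389 out of 1949 cells) are consistent with holding out roughly 20% of the cells for validation. The sketch below shows a plain random split with that ratio. The 20% fraction, the seed, and the helper `split_indices` are assumptions for illustration; the actual logic of `spt.split_dataset_by_slices` may differ.

```python
import random

def split_indices(n_cells, val_fraction=0.2, seed=0):
    """Shuffle cell indices and hold out a fraction of them for validation."""
    idx = list(range(n_cells))
    random.Random(seed).shuffle(idx)
    n_val = int(n_cells * val_fraction)  # rounds down
    return idx[n_val:], idx[:n_val]

train_idx, val_idx = split_indices(1949)
print(len(train_idx), len(val_idx))  # 1560 389
```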


In [19]:

model = spt.SpaTranslator()

SpaTranslator model initialized.


In [20]:
model.load_data(S1_RNA_data, S1_ATAC_data, S2_ATAC_data, train_id, validation_id, mode='A2R') 

Data successfully loaded.


In [21]:
model.preprocess_data(normalize_total=False, log1p=False, use_hvg=False)

Recommended to use default settings for optimal results.
Binarizing ATAC data.
Filtering out peaks present in fewer than 0.50% of cells.
Applying TF-IDF transformation.
Normalizing data to range [0, 1] for ATAC.
Data preprocessing completed.
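
The log above lists four ATAC steps: binarization, removal of peaks present in fewer than 0.5% of cells, a TF-IDF transformation, and scaling to [0, 1]. The sketch below walks through those steps on a tiny count matrix. It is an illustration, not the code inside `model.preprocess_data`: the TF-IDF variant, the scaling by the global maximum, and the helper name `preprocess_atac` are all assumptions.

```python
import math

def preprocess_atac(counts, min_cell_frac=0.005):
    """Binarize, drop rare peaks, apply TF-IDF, and scale values to [0, 1]."""
    n_cells = len(counts)
    # 1) Binarize: any read in a peak marks the peak as open.
    binary = [[1 if v > 0 else 0 for v in row] for row in counts]
    # 2) Keep peaks that are open in at least min_cell_frac of the cells.
    n_open = [sum(col) for col in zip(*binary)]
    keep = [j for j, c in enumerate(n_open) if c > 0 and c >= min_cell_frac * n_cells]
    binary = [[row[j] for j in keep] for row in binary]
    # 3) TF-IDF: per-cell frequency times a log inverse peak frequency.
    idf = [math.log(1 + n_cells / n_open[j]) for j in keep]
    tfidf = []
    for row in binary:
        total = sum(row) or 1  # avoid dividing by zero for empty cells
        tfidf.append([v / total * w for v, w in zip(row, idf)])
    # 4) Divide by the global maximum so every value lies in [0, 1].
    top = max(max(row) for row in tfidf) or 1.0
    return [[v / top for v in row] for row in tfidf], keep

# Toy matrix: 4 cells x 4 peaks; a 50% threshold keeps the demo readable.
counts = [[3, 0, 1, 0], [0, 0, 2, 0], [1, 0, 0, 0], [0, 1, 4, 0]]
scaled, kept_peaks = preprocess_atac(counts, min_cell_frac=0.5)
print(kept_peaks)  # [0, 2]
```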


In [22]:
model.ATAC_data_p.var_names

Index(['chr1:3094810-3095311', 'chr1:3120251-3120752', 'chr1:3184960-3185461',
       'chr1:3360849-3361350', 'chr1:3399870-3400371', 'chr1:3414009-3414510',
       'chr1:3514678-3515179', 'chr1:3552386-3552887', 'chr1:3670824-3671325',
       'chr1:3671548-3672049',
       ...
       'chrY:90799093-90799594', 'chrY:90800223-90800724',
       'chrY:90801342-90801843', 'chrY:90803330-90803831',
       'chrY:90807496-90807997', 'chrY:90808593-90809094',
       'chrY:90810719-90811220', 'chrY:90811533-90812034',
       'chrY:90812148-90812649', 'chrY:90812685-90813186'],
      dtype='object', name='index', length=265014)

Following [scButterfly](https://www.nature.com/articles/s41467-024-47418-x), we split the ATAC peaks into groups according to the chromosome they lie on, which we found makes training easier. The list `chrom_list` built below stores the number of peaks on each chromosome, in the order the chromosomes appear in `var_names`.

In [23]:
chrom_list = []
current_chrom = None
count = 0

for peak in model.ATAC_data_p.var_names:
    chrom = peak.split(':')[0]  #get the chromosome
    if chrom != current_chrom:
        if current_chrom is not None:
            chrom_list.append(count)
        current_chrom = chrom
        count = 1
    else:
        count += 1

# add the last chromosome
chrom_list.append(count)

print(chrom_list)

[16883, 12699, 19116, 11759, 10537, 9448, 12005, 8442, 12310, 8423, 7167, 22138, 11966, 19434, 18302, 13699, 17873, 15369, 14837, 2576, 31]


In [24]:
sum(chrom_list)

265014
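
The counting loop above can also be written with `itertools.groupby`, which groups consecutive items that share a key. The sketch below shows this on a few made-up peak names. Like the loop, it assumes the peaks are ordered so that each chromosome forms one contiguous run.

```python
from itertools import groupby

# Hypothetical peak names in genome order, formatted like the var_names above.
peaks = [
    "chr1:100-600", "chr1:900-1400",
    "chr2:50-550",
    "chrX:10-510", "chrX:700-1200", "chrX:2000-2500",
]

# Count the peaks in each run of consecutive peaks on the same chromosome.
chrom_counts = [
    sum(1 for _ in group)
    for _, group in groupby(peaks, key=lambda p: p.split(":")[0])
]

print(chrom_counts)  # [2, 1, 3]
# The counts must add up to the total number of peaks.
assert sum(chrom_counts) == len(peaks)
```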

Because the number of cells is quite small (fewer than 2,000 per slice), we augment the training data.

In [25]:
model.augment_data(aug_type="cell_type_augmentation")

Performing cell type-based augmentation
Data augmentation completed.
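
One common form of cell-type-based augmentation, used in scButterfly, creates extra training pairs by matching the RNA profile of one cell with the ATAC profile of another cell of the same type. The sketch below illustrates that idea only. Whether `model.augment_data` works this way is an assumption, and the helper `cell_type_pairs` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict

def cell_type_pairs(cell_types, n_extra, seed=0):
    """Sample (rna_index, atac_index) pairs whose two cells share a cell type."""
    rng = random.Random(seed)
    by_type = defaultdict(list)
    for i, label in enumerate(cell_types):
        by_type[label].append(i)
    pairs = []
    for _ in range(n_extra):
        # Picking a random cell's label samples types in proportion to their size.
        members = by_type[rng.choice(cell_types)]
        pairs.append((rng.choice(members), rng.choice(members)))
    return pairs

# Toy labels standing in for adata.obs['cell_type'].
cell_types = ["neuron", "neuron", "glia", "neuron", "glia", "muscle"]
extra_pairs = cell_type_pairs(cell_types, n_extra=10)
print(len(extra_pairs))  # 10
```

Each sampled pair would then be added to the training set as one more paired RNA/ATAC example.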


## Training

In [26]:
model.construct_model(chrom_list=chrom_list)


------------------------------
Model Parameters
Mode: A2R
R_encoder_nlayer: 2
A_encoder_nlayer: 2
R_decoder_nlayer: 2
A_decoder_nlayer: 2
R_encoder_dim_list: [1217, 256, 128]
A_encoder_dim_list: [265014, 672, 128]
R_decoder_dim_list: [128, 256, 1217]
A_decoder_dim_list: [128, 672, 265014]
R_encoder_act_list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
A_encoder_act_list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
R_decoder_act_list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
A_decoder_act_list: [LeakyReLU(negative_slope=0.01), Sigmoid()]
Translator embed dim: 128
Translator input dim (RNA): 128
Translator input dim (ATAC): 128
Translator activation list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
Discriminator layers: 1
Discriminator dim list (RNA): [128]
Discriminator dim list (ATAC): [128]
Discriminator activation list: [Sigmoid()]
Dropout rate: 0.1
R_noise_rate: 0.

In [27]:
model.train_model(translator_epoch = 300)

RNA pretraining with VAE: 100%|█████████| 100/100 [09:50<00:00,  5.91s/it, train=0.0202, val=0.0222]
ATAC pretraining with VAE: 100%|████████| 100/100 [16:25<00:00,  9.86s/it, train=0.1123, val=0.1093]
Integrative training:  68%|████████▊    | 204/300 [49:11<24:09, 15.10s/it, train=0.4976, val=0.4905]

Integrative training early stop, validation loss does not improve in 50 epoches!


Integrative training:  68%|████████▊    | 204/300 [49:12<23:09, 14.47s/it, train=0.4976, val=0.4905]


Model training completed.
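
In the log above, integrative training stops at epoch 204 of 300 because the validation loss had not improved for 50 epochs. The sketch below shows this kind of patience-based early stopping on a made-up loss curve. It is a generic illustration, not the training loop of `model.train_model`; the helper name and the loss values are assumptions.

```python
def train_with_early_stopping(val_losses, patience=50):
    """Return (stop_epoch, best_epoch) for a sequence of validation losses."""
    best, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # New best validation loss: remember it and reset the wait.
            best, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop here.
            return epoch, best_epoch
    return len(val_losses) - 1, best_epoch

# Hypothetical loss curve: improves for 5 epochs, then plateaus.
losses = [1.0, 0.8, 0.6, 0.55, 0.5] + [0.5] * 20
stop_epoch, best_epoch = train_with_early_stopping(losses, patience=3)
print(stop_epoch, best_epoch)  # 7 4
```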


## Get prediction

You can obtain cross-modal predictions by calling ``model.test_model`` with its default settings. This function also provides additional evaluation metrics; see the [API](../../API/index.html) for details.

In [28]:
A2R_predict = model.test_model()

ATAC to RNA predicting...: 100%|████████████████████████████████████| 31/31 [00:01<00:00, 25.50it/s]


In [29]:
combined_tensor = torch.cat(A2R_predict.uns['A2R_embedding'], dim=0)

In [30]:
A2R_predict.uns['A2R_embedding'] = combined_tensor

In [31]:
A2R_predict.obsm['A2R_embedding'] = A2R_predict.uns['A2R_embedding'].numpy()

In [32]:
A2R_predict = A2R_predict[RNA_data_processed.obs_names].copy()

In [33]:
A2R_predict.obsm['spatial'] = RNA_data_processed.obsm['spatial']

In [34]:
A2R_predict

AnnData object with n_obs × n_vars = 1939 × 1217
    obs: 'tsse', 'n_fragment', 'frac_dup', 'frac_mito', 'sample', 'Sample', 'TSSEnrichment', 'ReadsInTSS', 'ReadsInPromoter', 'ReadsInBlacklist', 'PromoterRatio', 'PassQC', 'NucleosomeRatio', 'nMultiFrags', 'nMonoFrags', 'nFrags', 'nDiFrags', 'Gex_RiboRatio', 'Gex_nUMI', 'Gex_nGenes', 'Gex_MitoRatio', 'BlacklistRatio', 'array_col', 'array_row', 'Combined_Clusters', 'RNA_Clusters', 'ATAC_Clusters', 'cell_type'
    var: 'gene_ids', 'feature_types', 'genome', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'A2R_embedding'
    obsm: 'A2R_embedding', 'spatial'

We save the prediction result, `A2R_predict`, to an `.h5ad` file, together with the real S2 RNA data for later comparison.

In [37]:
# check uns dictionary
print(A2R_predict.uns.keys())

if 'A2R_embedding' in A2R_predict.uns:
    # convert tensor to numpy array
    if isinstance(A2R_predict.uns['A2R_embedding'], torch.Tensor):
        A2R_predict.uns['A2R_embedding'] = A2R_predict.uns['A2R_embedding'].numpy()

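# check the uns dictionary again after the conversion
print(A2R_predict.uns.keys())

# save the result (create the output directory first if it does not exist)
os.makedirs("./result", exist_ok=True)
A2R_predict.write_h5ad("./result/A2R_predict_E15_S1_pair_E15_S2_A2R.h5ad")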

odict_keys(['A2R_embedding'])
odict_keys(['A2R_embedding'])


In [38]:
RNA_data_processed.write_h5ad("./result/Real_RNA_data_processed_E15_S1_pair_E15_S2_A2R.h5ad")