## Tutorial for RNA to ATAC Cross-Modality Generation: Step 2

This is the tutorial for RNA-to-ATAC (denoted as R2A) cross-modality generation. We use the MISAR-seq dataset as the example for this task. Detailed results for this task can be found in Figure 3 of our paper.

Specifically, the model takes three datasets as input:
* a pair of spatial multi-omics datasets from the same slice (denoted as S1R and S1A in our paper and Figure 1)
* a single-modality spatial dataset from another slice (denoted as S2R in our paper and Figure 1)

Our goal is to generate the missing modality for S2, that is, S2A.

We divide the whole process into two files. This file is step 2, which contains the core generation framework. If you have not run step 1 yet, please go back to that file first.


## Read in data

In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import scanpy as sc

In [2]:
import torch
used_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# set the working CUDA device (skipped when only the CPU is available)
if used_device.type == 'cuda':
    torch.cuda.set_device(used_device)

In [None]:
import os
# replace with the path to your own SpaTranslator working directory
os.chdir('/home/yurworkingpth/SpatialTranslator')

In [5]:
os.chdir('/workdir/sm2888/ST')

In [6]:
import SpaTranslator as spt


    _____               ______                           __        __              
   / ___/ ____   ____ _/_  __/_____ ____ _ ____   _____ / /____ _ / /_ ____   _____
   \__ \ / __ \ / __ `/ / /  / ___// __ `// __ \ / ___// // __ `// __// __ \ / ___/
  ___/ // /_/ // /_/ / / /  / /   / /_/ // / / /(__  )/ // /_/ // /_ / /_/ // /    
 /____// .___/ \__,_/ /_/  /_/    \__,_//_/ /_//____//_/ \__,_/ \__/ \____//_/     
      /_/                                                                           

SpaTranslator v1.0.8     



We read in the `After_train` RNA data saved in step 1, together with the S1 ATAC data.

In [8]:
S1_ATAC_data = sc.read_h5ad('filtered_merged_E15_5-S1_atac.h5ad')

In [9]:
S1_ATAC_data.X.toarray()

array([[ 0,  4, 22, ...,  0,  0,  0],
       [ 0,  2, 19, ...,  0,  0,  0],
       [ 0,  2, 25, ...,  0,  0,  0],
       ...,
       [ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  0,  0,  0]])

In [10]:
S1_ATAC_data

AnnData object with n_obs × n_vars = 1949 × 70249
    obs: 'gex_barcode', 'atac_barcode', 'is_cell', 'excluded_reason', 'gex_raw_reads', 'gex_mapped_reads', 'gex_conf_intergenic_reads', 'gex_conf_exonic_reads', 'gex_conf_intronic_reads', 'gex_conf_exonic_unique_reads', 'gex_conf_exonic_antisense_reads', 'gex_conf_exonic_dup_reads', 'gex_exonic_umis', 'gex_conf_intronic_unique_reads', 'gex_conf_intronic_antisense_reads', 'gex_conf_intronic_dup_reads', 'gex_intronic_umis', 'gex_conf_txomic_unique_reads', 'gex_umis_count', 'gex_genes_count', 'atac_raw_reads', 'atac_unmapped_reads', 'atac_lowmapq', 'atac_dup_reads', 'atac_chimeric_reads', 'atac_mitochondrial_reads', 'atac_fragments', 'atac_TSS_fragments', 'atac_peak_region_fragments', 'atac_peak_region_cutsites', 'Sample', 'TSSEnrichment', 'ReadsInTSS', 'ReadsInPromoter', 'ReadsInBlacklist', 'PromoterRatio', 'PassQC', 'NucleosomeRatio', 'nMultiFrags', 'nMonoFrags', 'nFrags', 'nDiFrags', 'Gex_RiboRatio', 'Gex_nUMI', 'Gex_nGenes', 'Gex_MitoR

In [11]:
S1_RNA_data = sc.read_h5ad('After_train_E15_5-S1_expr.h5ad')

In [12]:
S1_RNA_data.obs

Unnamed: 0,gex_barcode,atac_barcode,is_cell,excluded_reason,gex_raw_reads,gex_mapped_reads,gex_conf_intergenic_reads,gex_conf_exonic_reads,gex_conf_intronic_reads,gex_conf_exonic_unique_reads,...,array_col,array_row,Harmony_ATAC_0.35,Harmony_RNA_0.7,Harmony_Combined_1.2_mergeCortex,ReadsInPeaks,FRIP,Annotation_for_Combined,Ground Truth,cell_type
GCCGCAACGCCGCAAC-1,GCCGCAACGCCGCAAC-1,GCCGCAACGCCGCAAC-1,1,0,200903,184493,21684,53624,99766,47535,...,24,27,1,3,6,251389,0.411932,Cartilage_2,Cartilage_2,Cartilage_2
GCCATTCTGCCATTCT-1,GCCATTCTGCCATTCT-1,GCCATTCTGCCATTCT-1,1,0,332851,305088,44055,92452,149887,82154,...,34,17,10,3,13,203006,0.342920,Diencephalon_and_hindbrain,Diencephalon_and_hindbrain,Diencephalon_and_hindbrain
TGCGGACCTGCGGACC-1,TGCGGACCTGCGGACC-1,TGCGGACCTGCGGACC-1,1,0,285439,264247,40488,76256,129985,67674,...,22,29,1,3,10,232649,0.405886,Thalamus,Thalamus,Thalamus
ACTATGCAACTATGCA-1,ACTATGCAACTATGCA-1,ACTATGCAACTATGCA-1,0,1,179671,164554,28576,51311,71767,45917,...,13,38,4,11,8,214881,0.396050,Mesenchyme,Mesenchyme,Mesenchyme
TAGATCTATAGATCTA-1,TAGATCTATAGATCTA-1,TAGATCTATAGATCTA-1,1,0,305439,280980,38357,86695,138397,77058,...,20,31,10,3,13,206150,0.382904,Diencephalon_and_hindbrain,Diencephalon_and_hindbrain,Diencephalon_and_hindbrain
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TGAGTCTGTCCGGCCT-1,TGAGTCTGTCCGGCCT-1,TGAGTCTGTCCGGCCT-1,0,0,3103,2873,509,1042,1110,948,...,23,45,10,3,14,1190,0.372108,Midbrain,Midbrain,Midbrain
GATCTTACATAAGGAG-1,GATCTTACATAAGGAG-1,GATCTTACATAAGGAG-1,0,0,933,861,167,293,311,256,...,27,1,4,13,7,1074,0.359920,Cartilage_3,Cartilage_3,Cartilage_3
TCCGGCCTATAAGGAG-1,TCCGGCCTATAAGGAG-1,TCCGGCCTATAAGGAG-1,0,0,829,761,186,207,310,185,...,6,1,13,13,7,1095,0.374743,Cartilage_3,Cartilage_3,Cartilage_3
AGAGAAGGATAAGGAG-1,AGAGAAGGATAAGGAG-1,AGAGAAGGATAAGGAG-1,0,0,816,755,126,255,286,228,...,5,1,4,13,7,971,0.362313,Cartilage_3,Cartilage_3,Cartilage_3


In [13]:
S1_ATAC_data.obs

Unnamed: 0,gex_barcode,atac_barcode,is_cell,excluded_reason,gex_raw_reads,gex_mapped_reads,gex_conf_intergenic_reads,gex_conf_exonic_reads,gex_conf_intronic_reads,gex_conf_exonic_unique_reads,...,BlacklistRatio,array_col,array_row,Harmony_ATAC_0.35,Harmony_RNA_0.7,Harmony_Combined_1.2_mergeCortex,ReadsInPeaks,FRIP,Annotation_for_Combined,cell_type
GCCGCAACGCCGCAAC-1_E15_5-S1,GCCGCAACGCCGCAAC-1,GCCGCAACGCCGCAAC-1,1,0,200903,184493,21684,53624,99766,47535,...,0.055156,24,27,1,3,6,251389,0.411932,Cartilage_2,Cartilage_2
GCCATTCTGCCATTCT-1_E15_5-S1,GCCATTCTGCCATTCT-1,GCCATTCTGCCATTCT-1,1,0,332851,305088,44055,92452,149887,82154,...,0.048516,34,17,10,3,13,203006,0.342920,Diencephalon_and_hindbrain,Diencephalon_and_hindbrain
TGCGGACCTGCGGACC-1_E15_5-S1,TGCGGACCTGCGGACC-1,TGCGGACCTGCGGACC-1,1,0,285439,264247,40488,76256,129985,67674,...,0.051738,22,29,1,3,10,232649,0.405886,Thalamus,Thalamus
ACTATGCAACTATGCA-1_E15_5-S1,ACTATGCAACTATGCA-1,ACTATGCAACTATGCA-1,0,1,179671,164554,28576,51311,71767,45917,...,0.049135,13,38,4,11,8,214881,0.396050,Mesenchyme,Mesenchyme
TAGATCTATAGATCTA-1_E15_5-S1,TAGATCTATAGATCTA-1,TAGATCTATAGATCTA-1,1,0,305439,280980,38357,86695,138397,77058,...,0.054035,20,31,10,3,13,206150,0.382904,Diencephalon_and_hindbrain,Diencephalon_and_hindbrain
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TGAGTCTGTCCGGCCT-1_E15_5-S1,TGAGTCTGTCCGGCCT-1,TGAGTCTGTCCGGCCT-1,0,0,3103,2873,509,1042,1110,948,...,0.069375,23,45,10,3,14,1190,0.372108,Midbrain,Midbrain
GATCTTACATAAGGAG-1_E15_5-S1,GATCTTACATAAGGAG-1,GATCTTACATAAGGAG-1,0,0,933,861,167,293,311,256,...,0.042560,27,1,4,13,7,1074,0.359920,Cartilage_3,Cartilage_3
TCCGGCCTATAAGGAG-1_E15_5-S1,TCCGGCCTATAAGGAG-1,TCCGGCCTATAAGGAG-1,0,0,829,761,186,207,310,185,...,0.045517,6,1,13,13,7,1095,0.374743,Cartilage_3,Cartilage_3
AGAGAAGGATAAGGAG-1_E15_5-S1,AGAGAAGGATAAGGAG-1,AGAGAAGGATAAGGAG-1,0,0,816,755,126,255,286,228,...,0.053731,5,1,4,13,7,971,0.362313,Cartilage_3,Cartilage_3


In [14]:
S2_RNA_data = sc.read_h5ad('After_train_E15_5-S2_expr.h5ad')

In [15]:
S2_RNA_data.obs['cell_type'] = S2_RNA_data.obs['cell_type'].astype('category')

In [16]:
S2_RNA_data

AnnData object with n_obs × n_vars = 1939 × 32285
    obs: 'gex_barcode', 'atac_barcode', 'is_cell', 'excluded_reason', 'gex_raw_reads', 'gex_mapped_reads', 'gex_conf_intergenic_reads', 'gex_conf_exonic_reads', 'gex_conf_intronic_reads', 'gex_conf_exonic_unique_reads', 'gex_conf_exonic_antisense_reads', 'gex_conf_exonic_dup_reads', 'gex_exonic_umis', 'gex_conf_intronic_unique_reads', 'gex_conf_intronic_antisense_reads', 'gex_conf_intronic_dup_reads', 'gex_intronic_umis', 'gex_conf_txomic_unique_reads', 'gex_umis_count', 'gex_genes_count', 'atac_raw_reads', 'atac_unmapped_reads', 'atac_lowmapq', 'atac_dup_reads', 'atac_chimeric_reads', 'atac_mitochondrial_reads', 'atac_fragments', 'atac_TSS_fragments', 'atac_peak_region_fragments', 'atac_peak_region_cutsites', 'Sample', 'TSSEnrichment', 'ReadsInTSS', 'ReadsInPromoter', 'ReadsInBlacklist', 'PromoterRatio', 'PassQC', 'NucleosomeRatio', 'nMultiFrags', 'nMonoFrags', 'nFrags', 'nDiFrags', 'Gex_RiboRatio', 'Gex_nUMI', 'Gex_nGenes', 'Gex_MitoR

## Preprocessing and Constructing the Neighbor Graph

In [17]:

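# build a 6-nearest-neighbor spatial graph for each dataset
n_neighbors = 6
for adata in [S1_RNA_data, S1_ATAC_data, S2_RNA_data]:
    spt.build_spatial_graph(adata, knn_neighbors=n_neighbors, method='KNN')

# split the paired S1 cells into training and validation sets
train_id, validation_id = spt.split_dataset_by_slices(S1_RNA_data, S1_ATAC_data)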



Constructing the spatial graph
Generated graph with 11694 edges across 1949 cells.
Average neighbors per cell: 6.0000
Constructing the spatial graph
Generated graph with 11694 edges across 1949 cells.
Average neighbors per cell: 6.0000
Constructing the spatial graph
Generated graph with 11634 edges across 1939 cells.
Average neighbors per cell: 6.0000
Training set size: 1560
Validation set size: 389
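
The edge counts in the log follow from the KNN construction: every cell is linked to its 6 nearest spatial neighbors, so 1949 cells give 1949 × 6 = 11694 directed edges and 1939 cells give 11634. The sketch below reproduces that bookkeeping on a toy grid in plain Python. The helper `knn_edges` is hypothetical and only illustrates the idea; it is not how `spt.build_spatial_graph` is implemented.

```python
import math

def knn_edges(coords, k):
    """Return directed edges (i, j) linking each point to its k nearest neighbors."""
    edges = []
    for i, p in enumerate(coords):
        # distance from point i to every other point
        dists = [(math.dist(p, q), j) for j, q in enumerate(coords) if j != i]
        dists.sort()
        edges.extend((i, j) for _, j in dists[:k])
    return edges

# toy 5 x 5 grid of spot coordinates (column, row)
coords = [(x, y) for x in range(5) for y in range(5)]
edges = knn_edges(coords, k=6)

n_cells = len(coords)
print(f"Generated graph with {len(edges)} edges across {n_cells} cells.")
print(f"Average neighbors per cell: {len(edges) / n_cells:.4f}")
```

Because each cell contributes exactly `k` outgoing edges, the average number of neighbors per cell is always `k`, which matches the `6.0000` reported for all three datasets.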


## Data Splitting and Model Loading

In [18]:
model = spt.SpaTranslator()

SpaTranslator model initialized.


In [19]:
model.load_data(S1_RNA_data, S1_ATAC_data, S2_RNA_data, train_id, validation_id, mode="R2A")

Data successfully loaded.


In [20]:
model.preprocess_data() 

Applying total count normalization for RNA.
Applying log1p transformation for RNA.
Selecting top 3000 highly variable genes for RNA.
Binarizing ATAC data.
Filtering out peaks present in fewer than 0.50% of cells.
Applying TF-IDF transformation.
Normalizing data to range [0, 1] for ATAC.
Data preprocessing completed.
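
To make the logged steps concrete, here is a minimal plain-Python sketch of the same sequence on toy matrices. It is an illustration under stated assumptions, not the code behind `model.preprocess_data()`: the `target_sum` of 10,000, the particular TF-IDF variant, and the global (rather than per-peak) scaling are all assumptions, and the helper names are hypothetical. Highly variable gene selection is omitted.

```python
import math

def normalize_log1p(counts, target_sum=1e4):
    """Scale each cell to the same total count, then apply log(1 + x)."""
    out = []
    for cell in counts:
        total = sum(cell) or 1.0  # avoid division by zero for empty cells
        out.append([math.log1p(v / total * target_sum) for v in cell])
    return out

def binarize_and_filter(counts, min_frac=0.005):
    """Binarize peak counts and drop peaks open in fewer than min_frac of cells."""
    binary = [[1 if v > 0 else 0 for v in cell] for cell in counts]
    n_cells = len(binary)
    keep = [j for j in range(len(binary[0]))
            if sum(cell[j] for cell in binary) >= min_frac * n_cells]
    return [[cell[j] for j in keep] for cell in binary], keep

def tfidf_scaled(binary):
    """One common TF-IDF variant, then division by the maximum to reach [0, 1]."""
    n_cells = len(binary)
    peak_sums = [sum(cell[j] for cell in binary) for j in range(len(binary[0]))]
    idf = [math.log(1 + n_cells / s) for s in peak_sums]
    tfidf = []
    for cell in binary:
        total = sum(cell) or 1  # number of open peaks in this cell
        tfidf.append([v / total * w for v, w in zip(cell, idf)])
    hi = max(max(row) for row in tfidf) or 1.0
    return [[v / hi for v in row] for row in tfidf]

# toy count matrices: rows are cells, columns are genes or peaks
rna = [[0, 4, 6], [10, 0, 10]]
atac = [[0, 4, 22, 0], [0, 2, 19, 0], [0, 0, 25, 1]]

rna_p = normalize_log1p(rna)
atac_b, kept = binarize_and_filter(atac)
atac_p = tfidf_scaled(atac_b)
print("kept peak columns:", kept)
```

In the toy ATAC matrix the first peak is closed in every cell, so it is the only column removed by the 0.5% filter.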


In [21]:
model.ATAC_data_p.var_names

Index(['chr1:3094792-3095311', 'chr1:3670633-3671325', 'chr1:3671531-3672049',
       'chr1:4491884-4492396', 'chr1:4492853-4493362', 'chr1:4493405-4493920',
       'chr1:4496366-4497015', 'chr1:4571647-4572150', 'chr1:4785487-4785998',
       'chr1:4807537-4808043',
       ...
       'chrY:90740951-90741813', 'chrY:90742388-90742908',
       'chrY:90743047-90743561', 'chrY:90743708-90744240',
       'chrY:90799093-90799597', 'chrY:90800221-90800724',
       'chrY:90803330-90803836', 'chrY:90807246-90807997',
       'chrY:90810536-90812649', 'chrY:90812651-90813186'],
      dtype='object', length=70249)

Similar to [scButterfly](https://www.nature.com/articles/s41467-024-47418-x), we found that it helps to divide the ATAC peaks into groups according to their chromosome, which makes training easier. The cell below counts the number of peaks on each chromosome; it relies on the peaks in `var_names` being ordered by chromosome.

In [22]:
chrom_list = []
current_chrom = None
count = 0

for peak in model.ATAC_data_p.var_names:
    chrom = peak.split(':')[0]  # get the chromosome name
    if chrom != current_chrom:
        if current_chrom is not None:
            chrom_list.append(count)
        current_chrom = chrom
        count = 1
    else:
        count += 1

# add the last chromosome
chrom_list.append(count)
print(chrom_list)

[4246, 3244, 5610, 3031, 2706, 2405, 3114, 2099, 3297, 2166, 1992, 5890, 3021, 5216, 4744, 3425, 4923, 4082, 3969, 1050, 19]


In [23]:
sum(chrom_list)

70249
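
The same per-chromosome counts can be written more compactly with `itertools.groupby`, which groups consecutive items that share a key. The sketch below uses a handful of toy peak names in the same `chrom:start-end` format; the helper `peaks_per_chromosome` and the derived `blocks` list are illustrative names, not part of SpaTranslator.

```python
from itertools import accumulate, groupby

def peaks_per_chromosome(peak_names):
    """Count consecutive peaks that share a chromosome prefix ('chr1:...')."""
    return [sum(1 for _ in group)
            for _, group in groupby(peak_names, key=lambda p: p.split(':')[0])]

# toy peak names in the same format as model.ATAC_data_p.var_names
peaks = ['chr1:3094792-3095311', 'chr1:3670633-3671325',
         'chr2:4491884-4492396', 'chrY:90740951-90741813',
         'chrY:90742388-90742908', 'chrY:90743047-90743561']

chrom_list = peaks_per_chromosome(peaks)
print(chrom_list)  # [2, 1, 3]

# column range of each chromosome block in the peak matrix
bounds = [0, *accumulate(chrom_list)]
blocks = list(zip(bounds[:-1], bounds[1:]))
print(blocks)  # [(0, 2), (2, 3), (3, 6)]
```

Like the loop above, this only works when peaks from the same chromosome are adjacent. Checking that `sum(chrom_list)` equals the number of peaks, as done in the previous cell, is a cheap way to confirm that no peak was dropped.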

Because the number of cells is quite small (fewer than 2,000 per slice), we apply data augmentation before training.

In [24]:
model.augment_data(aug_type="cell_type_augmentation")

Performing cell type-based augmentation
Data augmentation completed.
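
The log does not show what `cell_type_augmentation` does internally. One plausible form, following the cell-type-guided augmentation idea described for scButterfly, is to create extra training pairs by matching the RNA profile of one cell with the ATAC profile of another cell that carries the same `cell_type` label. The sketch below illustrates only that idea; `cell_type_pairs` is a hypothetical helper, and the actual behavior of `model.augment_data` may differ.

```python
import random
from collections import defaultdict

def cell_type_pairs(cell_types, n_new, seed=0):
    """Sample (rna_index, atac_index) pairs whose cells share a cell type."""
    rng = random.Random(seed)
    by_type = defaultdict(list)
    for idx, label in enumerate(cell_types):
        by_type[label].append(idx)
    pairs = []
    for _ in range(n_new):
        rna_idx = rng.randrange(len(cell_types))
        # take the ATAC profile from any cell with the same label
        atac_idx = rng.choice(by_type[cell_types[rna_idx]])
        pairs.append((rna_idx, atac_idx))
    return pairs

# toy labels taken from the cell_type column shown earlier
labels = ['Thalamus', 'Midbrain', 'Thalamus', 'Cartilage_3', 'Midbrain']
extra_pairs = cell_type_pairs(labels, n_new=10)
print(extra_pairs)
```

Every sampled pair stays within one cell type, so the extra pairs enlarge the training set without mixing profiles from different tissue regions.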


In [25]:
model.construct_model(chrom_list=chrom_list)


------------------------------
Model Parameters
Mode: R2A
R_encoder_nlayer: 2
A_encoder_nlayer: 2
R_decoder_nlayer: 2
A_decoder_nlayer: 2
R_encoder_dim_list: [3000, 256, 128]
A_encoder_dim_list: [70249, 672, 128]
R_decoder_dim_list: [128, 256, 3000]
A_decoder_dim_list: [128, 672, 70249]
R_encoder_act_list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
A_encoder_act_list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
R_decoder_act_list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
A_decoder_act_list: [LeakyReLU(negative_slope=0.01), Sigmoid()]
Translator embed dim: 128
Translator input dim (RNA): 128
Translator input dim (ATAC): 128
Translator activation list: [LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01), LeakyReLU(negative_slope=0.01)]
Discriminator layers: 1
Discriminator dim list (RNA): [128]
Discriminator dim list (ATAC): [128]
Discriminator activation list: [Sigmoid()]
Dropout rate: 0.1
R_noise_rate: 0.5
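
The hidden width of 672 in `A_encoder_dim_list` equals 21 × 32, where 21 is the number of entries in `chrom_list`. This suggests, although the summary does not state it, that each chromosome block feeds its own 32 hidden units. Under that assumption, the arithmetic below compares the number of first-layer weights with a fully connected layer of the same shape.

```python
chrom_list = [4246, 3244, 5610, 3031, 2706, 2405, 3114, 2099, 3297, 2166, 1992,
              5890, 3021, 5216, 4744, 3425, 4923, 4082, 3969, 1050, 19]

n_peaks = sum(chrom_list)              # 70249, the ATAC input width
hidden = 672                           # second entry of A_encoder_dim_list
per_chrom = hidden // len(chrom_list)  # assumed hidden units per chromosome

# weights of one fully connected layer 70249 -> 672
dense_weights = n_peaks * hidden
# weights when each chromosome only feeds its own hidden units (assumption)
block_weights = sum(n * per_chrom for n in chrom_list)

print(f"units per chromosome: {per_chrom}")
print(f"fully connected: {dense_weights:,} weights")
print(f"chromosome-wise: {block_weights:,} weights")
```

With equal-width blocks the chromosome-wise layout needs 21 times fewer weights than the dense layer, which is one reason splitting by chromosome can make training easier on a dataset of this size.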


In [26]:
model.train_model(translator_epoch = 300)

RNA pretraining with VAE: 100%|█████████| 100/100 [03:33<00:00,  2.14s/it, train=0.0652, val=0.0712]
ATAC pretraining with VAE: 100%|████████| 100/100 [08:52<00:00,  5.33s/it, train=0.2114, val=0.2074]
Integrative training:  62%|████████     | 186/300 [33:57<20:48, 10.95s/it, train=0.9907, val=0.9829]

Integrative training early stop, validation loss does not improve in 50 epoches!
Model training completed.
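
The integrative stage stopped at epoch 186 of 300 because the validation loss had not improved for 50 epochs. A patience-based early stopping rule of this kind can be sketched as follows. The class `EarlyStopping` and the toy loss curve are illustrative assumptions, chosen so that the stop lands at epoch 186 as in the log; they are not taken from the SpaTranslator source.

```python
class EarlyStopping:
    """Stop training when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=50):
        self.patience = patience
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# toy validation curve: improves for 136 epochs, then plateaus
val_losses = [1.0 / (epoch + 1) for epoch in range(136)] + [0.5] * 200

stopper = EarlyStopping(patience=50)
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        print(f"early stop at epoch {epoch}")
        break
```

In practice the weights from the best epoch are usually kept, so stopping early costs nothing in validation performance while saving the remaining epochs.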





## Get Prediction

In [27]:
R2A_predict = model.test_model()

RNA to ATAC predicting...: 100%|████████████████████████████████████| 31/31 [00:00<00:00, 32.90it/s]


In [28]:
combined_tensor = torch.cat(R2A_predict.uns['R2A_embedding'], dim=0)

In [29]:
R2A_predict.uns['R2A_embedding'] = combined_tensor

In [30]:
# move the tensor to the CPU before converting, in case it still lives on the GPU
R2A_predict.obsm['R2A_embedding'] = R2A_predict.uns['R2A_embedding'].detach().cpu().numpy()
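
The prediction step ran in 31 batches, and the embeddings come back as one tensor per batch, which is why they are concatenated along the cell axis before being stored in `obsm`. The sketch below mimics this with plain lists. The batch size of 64 is an assumption (it is one value that yields 31 batches for 1939 cells), and the embedding width of 128 is borrowed from the translator dimension in the model summary.

```python
import math

n_cells = 1939      # cells in S2
batch_size = 64     # assumed; it reproduces the 31 batches in the progress bar
emb_dim = 128       # assumed; translator embedding dimension from the summary

n_batches = math.ceil(n_cells / batch_size)

# stand-in for the list of per-batch embeddings (lists instead of tensors)
batches = []
for b in range(n_batches):
    rows = min(batch_size, n_cells - b * batch_size)
    batches.append([[0.0] * emb_dim for _ in range(rows)])

# plain-list equivalent of torch.cat(batches, dim=0): stack rows batch after batch
combined = [row for batch in batches for row in batch]
print(n_batches, len(combined), len(combined[0]))
```

After concatenation there is exactly one embedding row per cell, in the same order as the cells in S2, so the result can be attached to `obsm` directly.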

In [31]:
R2A_predict

AnnData object with n_obs × n_vars = 1939 × 70249
    obs: 'gex_barcode', 'atac_barcode', 'is_cell', 'excluded_reason', 'gex_raw_reads', 'gex_mapped_reads', 'gex_conf_intergenic_reads', 'gex_conf_exonic_reads', 'gex_conf_intronic_reads', 'gex_conf_exonic_unique_reads', 'gex_conf_exonic_antisense_reads', 'gex_conf_exonic_dup_reads', 'gex_exonic_umis', 'gex_conf_intronic_unique_reads', 'gex_conf_intronic_antisense_reads', 'gex_conf_intronic_dup_reads', 'gex_intronic_umis', 'gex_conf_txomic_unique_reads', 'gex_umis_count', 'gex_genes_count', 'atac_raw_reads', 'atac_unmapped_reads', 'atac_lowmapq', 'atac_dup_reads', 'atac_chimeric_reads', 'atac_mitochondrial_reads', 'atac_fragments', 'atac_TSS_fragments', 'atac_peak_region_fragments', 'atac_peak_region_cutsites', 'Sample', 'TSSEnrichment', 'ReadsInTSS', 'ReadsInPromoter', 'ReadsInBlacklist', 'PromoterRatio', 'PassQC', 'NucleosomeRatio', 'nMultiFrags', 'nMonoFrags', 'nFrags', 'nDiFrags', 'Gex_RiboRatio', 'Gex_nUMI', 'Gex_nGenes', 'Gex_MitoR

In [32]:
R2A_predict.obsm['spatial'] = S2_RNA_data.obsm['spatial']

In [33]:
# check the uns dictionary
print(R2A_predict.uns.keys())

if 'R2A_embedding' in R2A_predict.uns:
    # convert the tensor to a NumPy array so that it can be saved to h5ad
    if isinstance(R2A_predict.uns['R2A_embedding'], torch.Tensor):
        R2A_predict.uns['R2A_embedding'] = R2A_predict.uns['R2A_embedding'].numpy()

# check the uns dictionary again
print(R2A_predict.uns.keys())

# make sure the output directory exists before saving
os.makedirs("./result", exist_ok=True)
R2A_predict.write_h5ad("./result/R2A_predict_E15_S1_pair_E15_S2_R2A.h5ad")

odict_keys(['R2A_embedding'])
odict_keys(['R2A_embedding'])
