# Transformer Model

My idea here is to use transcription factor (TF) expression, regulatory element (RE) accessibility, and the distance from each RE to the target gene (TG) transcription start site (TSS) as the input to a transformer model that predicts TG expression. The transformer outputs a predicted expression value for each TG, which is compared to the measured expression in the dataset.

In [1]:
!hostnamectl

   Static hostname: psh01com1hcom35
         Icon name: computer-server
           Chassis: server
        Machine ID: 6860da98c8574f44be8f2ea25abdb7fb
           Boot ID: bbdde321ce964aff96f74d2bfdfe077c
  Operating System: Red Hat Enterprise Linux 8.10 (Ootpa)
       CPE OS Name: cpe:/o:redhat:enterprise_linux:8::baseos
            Kernel: Linux 4.18.0-553.22.1.el8_10.x86_64
      Architecture: x86-64


In [2]:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import pybedtools
from grn_inference import utils

torch.manual_seed(1)
np.random.seed(42)

project_dir = "/gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER"
mm10_genome_dir = os.path.join(project_dir, "data/reference_genome/mm10")
mm10_gene_tss_file = os.path.join(project_dir, "data/genome_annotation/mm10/mm10_TSS.bed")
ground_truth_dir = os.path.join(project_dir, "ground_truth_files")
sample_input_dir = os.path.join(project_dir, "input/mESC/filtered_L2_E7.5_rep1")
output_dir = os.path.join(project_dir, "output/transformer_testing_output")

### Splitting the mm10 genome into ranges

Peak locations differ between samples. To let the method work across multiple samples, I will map peaks onto fixed genomic ranges so the model can learn from a shared coordinate system. If a peak overlaps two genomic ranges, it is assigned to the range that contains the majority of the peak. If a peak is evenly split between two ranges, it is assigned to one of them at random.

#### Read in the mm10 gene TSS bed file

In [3]:
mm10_fasta_file = os.path.join(mm10_genome_dir, "chr1.fa")
mm10_chrom_sizes_file = os.path.join(mm10_genome_dir, "chrom.sizes")

In [48]:
print("Reading and formatting TSS bed file")
mm10_gene_tss_bed = pybedtools.BedTool(mm10_gene_tss_file)
gene_tss_df = (
    mm10_gene_tss_bed
    .filter(lambda x: x.chrom == "chr1")
    .saveas(os.path.join(mm10_genome_dir, "mm10_ch1_gene_tss.bed"))
    .to_dataframe()
    .sort_values(by="start", ascending=True)
    )
gene_tss_df.head()



Reading and formatting TSS bed file


Unnamed: 0,chrom,start,end,name,score,strand
0,chr1,3671498,3671498,Xkr4,.,-
1,chr1,4360303,4360303,Rp1,.,-
2,chr1,4360314,4360314,Rp1,.,-
3,chr1,4409241,4409241,Rp1,.,-
4,chr1,4497354,4497354,Sox17,.,-


#### Read in the scATAC-seq dataset

We also need the scATAC-seq dataset that will be used for training. We load the processed scATAC-seq counts from a Parquet file.

In [4]:
mesc_atac_data = pd.read_parquet(os.path.join(sample_input_dir, "mESC_filtered_L2_E7.5_rep1_ATAC_processed.parquet")).set_index("peak_id")
mesc_atac_peak_loc = mesc_atac_data.index
mesc_atac_peak_loc = utils.format_peaks(mesc_atac_peak_loc)
mesc_atac_peak_loc = mesc_atac_peak_loc[mesc_atac_peak_loc["chromosome"] == "chr1"]
mesc_atac_peak_loc = mesc_atac_peak_loc.rename(columns={"chromosome":"chrom"})
mesc_atac_peak_loc.head()

Unnamed: 0,chrom,start,end,strand,peak_id
0,chr1,3035602,3036202,.,chr1:3035602-3036202
1,chr1,3062653,3063253,.,chr1:3062653-3063253
2,chr1,3072313,3072913,.,chr1:3072313-3072913
3,chr1,3191496,3192096,.,chr1:3191496-3192096
4,chr1,3340575,3341175,.,chr1:3340575-3341175


We will also restrict the scATAC-seq data to only use chromatin accessibility data for chromosome 1 for now.

In [113]:
mesc_atac_data_chr1 = mesc_atac_data[mesc_atac_data.index.isin(mesc_atac_peak_loc.peak_id)]
mesc_atac_data_chr1.iloc[:5, :3].head()

Unnamed: 0_level_0,E7.5_rep1#AAACCGGCAGAAATGC-1,E7.5_rep1#AAACGGATCATAACTG-1,E7.5_rep1#AAAGCACCATTAGCGC-1
peak_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
chr1:3035602-3036202,0.0,0.0,0.0
chr1:3062653-3063253,0.0,0.0,0.0
chr1:3072313-3072913,0.0,0.0,0.0
chr1:3191496-3192096,0.0,0.0,0.0
chr1:3340575-3341175,0.0,0.0,0.0


#### Read in the scRNA-seq dataset

In addition to the ATAC-seq dataset, we also need the corresponding gene expression values, loaded from the processed scRNA-seq Parquet file.

In [5]:
mesc_rna_data = pd.read_parquet(os.path.join(sample_input_dir, "mESC_filtered_L2_E7.5_rep1_RNA_processed.parquet")).set_index("gene_id")
mesc_rna_data.iloc[0:5, 0:3].head()

Unnamed: 0_level_0,E7.5_rep1#AAACCGGCAGAAATGC-1,E7.5_rep1#AAACGGATCATAACTG-1,E7.5_rep1#AAAGCACCATTAGCGC-1
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Xkr4,0.0,0.0,0.0
Sox17,0.0,0.0,0.0
Mrpl15,0.0,0.0,0.0
Lypla1,0.0,0.0,0.0
Tcea1,5.930213,0.0,0.0


### Genomic range distance to TG TSS

The distance between each genomic range and each potential TG TSS within 1 Mb of the range will be converted to a score using an exponential drop-off function:

$$\text{Scaling Factor} = e^{-\frac{\text{peak to TSS Distance}}{25000}}$$

This down-weights the regulatory effect that peaks in a distant genomic range can exert on a potential TG. The values in this matrix will be multiplied by the log1p-normalized and min-max-scaled RE accessibility.
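
As a quick sanity check on the drop-off function above, the sketch below evaluates the scaling factor at a few distances using the 25,000 bp decay constant from the formula. The function name `tss_distance_score` and its `scale_bp` parameter are hypothetical and only for illustration; they are not part of `grn_inference`.

```python
import numpy as np

def tss_distance_score(dist_bp, scale_bp=25_000):
    """Exponential drop-off exp(-distance / scale). Hypothetical helper."""
    dist_bp = np.asarray(dist_bp, dtype=float)
    return np.exp(-dist_bp / scale_bp)

# Distances in bp: at the TSS, one decay length away, 100 kb, and the 1 Mb cut-off
distances = np.array([0, 25_000, 100_000, 1_000_000])
scores = tss_distance_score(distances)
for d, s in zip(distances, scores):
    print(f"{int(d):>9,d} bp -> {float(s):.3g}")
```

A peak sitting on the TSS gets a score of 1, a peak 25 kb away gets e^-1 (about 0.37), and by the 1 Mb cut-off the score is effectively zero.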

We will tile the mm10 genome using the mm10 `chrom.sizes` file and a window size of 1 kb. We will use this to create our genomic ranges for mapping peaks to potential TGs.

In [114]:
window_size = 1000
mm10_genome_windows = pybedtools.bedtool.BedTool().window_maker(g=mm10_chrom_sizes_file, w=window_size)
mm10_chr1_windows = (
    mm10_genome_windows
    .filter(lambda x: x.chrom == "chr1")
    .saveas(os.path.join(mm10_genome_dir, f"mm10_chr1_windows_{window_size // 1000}kb.bed"))
    .to_dataframe()
    )
mm10_chr1_windows

Unnamed: 0,chrom,start,end
0,chr1,0,1000
1,chr1,1000,2000
2,chr1,2000,3000
3,chr1,3000,4000
4,chr1,4000,5000
...,...,...,...
195467,chr1,195467000,195468000
195468,chr1,195468000,195469000
195469,chr1,195469000,195470000
195470,chr1,195470000,195471000
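
For intuition, fixed-size tiling of a single chromosome can be approximated with plain NumPy. This is a minimal sketch, not a replacement for pybedtools: `tile_chromosome` is a hypothetical helper, the chromosome length is a toy value rather than the mm10 chr1 length, and truncating the final window at the chromosome end is a choice made in this sketch.

```python
import numpy as np
import pandas as pd

def tile_chromosome(chrom, chrom_length, window_size):
    """Split [0, chrom_length) into fixed-size, half-open windows.

    The final window is truncated at the chromosome end.
    """
    starts = np.arange(0, chrom_length, window_size)
    ends = np.minimum(starts + window_size, chrom_length)
    return pd.DataFrame({"chrom": chrom, "start": starts, "end": ends})

# Toy chromosome of 4,500 bp tiled into 1 kb windows
toy_windows = tile_chromosome("chr1", chrom_length=4_500, window_size=1_000)
print(toy_windows)
```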


#### Calculate the distance score between peaks and potential target genes

Now that we have the ATAC peak locations and the gene TSS locations, we can calculate the distance between each peak and each gene TSS. We only keep peak-TG rows where the distance between the peak and the gene TSS is at most 1 Mb.

In [None]:
peak_bed = pybedtools.BedTool.from_dataframe(
    mesc_atac_peak_loc[["chrom", "start", "end", "peak_id"]]
    ).saveas(os.path.join(output_dir, "peak_tmp.bed"))

tss_bed = pybedtools.BedTool.from_dataframe(
    gene_tss_df[["chrom", "start", "end", "name"]]
    ).saveas(os.path.join(output_dir, "tss_tmp.bed"))

genes_near_peaks = utils.find_genes_near_peaks(peak_bed, tss_bed)
genes_near_peaks = genes_near_peaks[genes_near_peaks["TSS_dist"] <= 1e6]

  peak_chr  peak_start  peak_end               peak_id gene_chr  gene_start  \
0     chr1     3035602   3036202  chr1:3035602-3036202     chr1     3671498   
1     chr1     3062653   3063253  chr1:3062653-3063253     chr1     3671498   
2     chr1     3072313   3072913  chr1:3072313-3072913     chr1     3671498   
3     chr1     3191496   3192096  chr1:3191496-3192096     chr1     3671498   
4     chr1     3340575   3341175  chr1:3340575-3341175     chr1     3671498   

   gene_end target_id  
0   3671498      Xkr4  
1   3671498      Xkr4  
2   3671498      Xkr4  
3   3671498      Xkr4  
4   3671498      Xkr4  


Unnamed: 0,peak_chr,peak_start,peak_end,peak_id,gene_chr,gene_start,gene_end,target_id,TSS_dist
400,chr1,4496754,4497354,chr1:4496754-4497354,chr1,4497354,4497354,Sox17,0
87149,chr1,74542289,74542889,chr1:74542289-74542889,chr1,74542888,74542888,Plcd4,1
333683,chr1,190169693,190170293,chr1:190169693-190170293,chr1,190170295,190170295,Prox1os,2
50111,chr1,55363149,55363749,chr1:55363149-55363749,chr1,55363753,55363753,Boll,4
110386,chr1,84839235,84839835,chr1:84839235-84839835,chr1,84839840,84839840,Fbxo36,5


In [10]:
gene_tss_outfile = os.path.join(mm10_genome_dir, "mm10_ch1_gene_tss.bed")

gene_tss_df = pybedtools.BedTool(gene_tss_outfile).to_dataframe().sort_values(by="start", ascending=True)
gene_tss_df

Unnamed: 0,chrom,start,end,name,score,strand
0,chr19,3197703,3197703,1700030N03Rik,.,-
1,chr19,3283010,3283010,Ighmbp2,.,-
2,chr19,3283041,3283041,Mrpl21,.,+
3,chr19,3288919,3288919,Mir6984,.,+
4,chr19,3323300,3323300,Cpt1a,.,+
...,...,...,...,...,...,...
976,chr19,60874538,60874538,Prdx3,.,-
977,chr19,60889748,60889748,Grk5,.,+
978,chr19,61140840,61140840,Zfp950,.,-
979,chr19,61176309,61176309,Gm7102,.,-


In [None]:
chr19_genes = gene_tss_df["name"].str.capitalize().to_list()
chr19_genes

['1700030n03rik',
 'Ighmbp2',
 'Mrpl21',
 'Mir6984',
 'Cpt1a',
 'Tesmin',
 'Tesmin',
 'Gal',
 'Ppp6r3',
 'Lrp5',
 '1810055g02rik',
 '1810055g02rik',
 'Kmt5b',
 'Gm51271',
 'Chka',
 'Tcirg1',
 'Tcirg1',
 'Tcirg1',
 'Ndufs8',
 'Aldh3b1',
 'Unc93b1',
 'Aldh3b3',
 'Aldh3b2',
 'Acy3',
 'Tbx10',
 'Nudt8',
 'Doc2g',
 'Ndufv1',
 'Gstp1',
 'Gstp2',
 'Gstp3',
 'Cabp2',
 'Cabp2',
 'Cdk2ap2',
 'Pitpnm1',
 'Pitpnm1',
 'Aip',
 'Aip',
 'Tmem134',
 'Cabp4',
 'Gpr152',
 'Coro1b',
 'Ptprcap',
 'Rps6kb2',
 'Carns1',
 'Tbc1d10c',
 'Ppp1ca',
 'Rad9a',
 'Clcf1',
 'Clcf1',
 'Pold4',
 'Pold4',
 'Mir6985',
 'Ssh3',
 'Ssh3',
 'Ankrd13d',
 'Grk2',
 'Kdm2a',
 'A930001c03rik',
 'Rhod',
 'A930001c03rik',
 'Syt12',
 '2010003k11rik',
 'Pcx',
 'Pcx',
 'Lrfn4',
 'Mir6986',
 'Rce1',
 'Rce1',
 'Gm960',
 'Sptbn2',
 'Rbm4b',
 'Rbm4',
 'Rbm4',
 'Gm21992',
 'Gm21992',
 'Ccs',
 'Ccdc87',
 'Ctsf',
 'Actn3',
 'Zdhhc24',
 'Zdhhc24',
 'Bbs1',
 'Dpp3',
 'Dpp3',
 'Peli3',
 'Peli3',
 'Mrpl11',
 'Npas4',
 'Slc29a2',
 'B4gat1',
 'Brms

In [13]:
rna_df_chr19 = mesc_rna_data[mesc_rna_data.index.isin(chr19_genes)]
rna_df_chr19

Unnamed: 0_level_0,E7.5_rep1#AAACCGGCAGAAATGC-1,E7.5_rep1#AAACGGATCATAACTG-1,E7.5_rep1#AAAGCACCATTAGCGC-1,E7.5_rep1#AAAGCTTGTTGGTGAC-1,E7.5_rep1#AACATTGTCGAGGAGT-1,E7.5_rep1#AACCTCCTCGCTCACT-1,E7.5_rep1#AACCTTAAGCACAGAA-1,E7.5_rep1#AACGACAAGCGATACT-1,E7.5_rep1#AACTACTCAACCGCCA-1,E7.5_rep1#AACTACTCAGTTGCGT-1,...,E7.5_rep1#TTTAACGAGCACTAAC-1,E7.5_rep1#TTTAAGGTCCACCTTA-1,E7.5_rep1#TTTACGAAGCTATATG-1,E7.5_rep1#TTTCATCAGGGATGAC-1,E7.5_rep1#TTTCATCAGTTTGAGC-1,E7.5_rep1#TTTGAGTCAATTGAGA-1,E7.5_rep1#TTTGGCTGTCACCAAA-1,E7.5_rep1#TTTGTGGCATGAATAG-1,E7.5_rep1#TTTGTGTTCTGTGAGT-1,E7.5_rep1#TTTGTTGGTTAATGAC-1
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ighmbp2,0.000000,6.083672,0.000000,5.501346,0.000000,0.000000,5.480901,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,5.561457,0.000000,0.000000,0.000000,0.000000,0.000000
Mrpl21,0.000000,0.000000,0.000000,5.501346,0.000000,5.441462,0.000000,0.000000,5.228401,0.000000,...,0.000000,0.000000,5.389757,0.000000,0.000000,0.000000,0.000000,5.808821,0.000000,5.415941
Ppp6r3,5.930213,0.000000,6.298473,0.000000,0.000000,0.000000,0.000000,6.529411,5.918863,0.000000,...,5.430994,6.376187,6.080620,0.000000,0.000000,6.471371,0.000000,0.000000,5.458771,0.000000
Lrp5,0.000000,0.000000,0.000000,5.501346,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,6.080620,5.475677,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
Kmt5b,0.000000,0.000000,0.000000,5.501346,0.000000,0.000000,5.480901,0.000000,0.000000,6.141192,...,0.000000,0.000000,0.000000,5.475677,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Eif3a,7.027051,7.468255,6.298473,7.577210,5.825522,6.537182,5.480901,6.934389,7.015682,6.833262,...,7.219098,7.068483,6.485322,7.669173,7.350009,5.779770,7.113341,7.192862,7.064796,6.798896
Fam45a,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,5.561457,0.000000,0.000000,0.000000,0.000000,0.000000
Prdx3,0.000000,0.000000,5.607163,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
Grk5,5.930213,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,5.475677,0.000000,0.000000,0.000000,0.000000,0.000000,5.415941


In [45]:
rna_df_chr19_most_common_genes = rna_df_chr19.fillna(0).ne(0).sum(axis=1).sort_values(ascending=False)
top_50_expressed_genes = rna_df_chr19_most_common_genes.index[:50].to_numpy()
# np.save writes a single array in .npy format; given a ".npz" name it would append ".npy"
np.save(os.path.join(output_dir, "top_50_expressed_chr19_genes.npy"), top_50_expressed_genes)

In [None]:
rna_df_chr19_gene_sum = rna_df_chr19.sum(axis=1).sort_values(ascending=False)
rna_df_chr19_gene_sum.index.to_list()[:100]

['Malat1',
 'Fau',
 'Eif3a',
 'Eef1g',
 'Cfl1',
 'Fth1',
 'Pgam1',
 'Snhg1',
 'Tnks2',
 'Hells',
 'Cox8a',
 'Ddb1',
 'Btaf1',
 'Scd2',
 'Tcf7l2',
 'Stip1',
 'Gldc',
 'Sf3b2',
 'Nolc1',
 'Atp5md',
 'Rtn3',
 'Ppp1r14b',
 'Uhrf2',
 'Smc3',
 'Smc5',
 'Oga',
 'Rbp4',
 'Sf1',
 'Rfx3',
 'Mark2',
 'Pten',
 'Banf1',
 'Gnaq',
 'Ganab',
 'Kdm2a',
 'Ppp1ca',
 'Pdcd11',
 'Tm9sf3',
 'Tmem132a',
 'Ppp6r3',
 'Tle4',
 'Tjp2',
 'Vti1a',
 'Atrnl1',
 'Incenp',
 'Kif11',
 'Lcor',
 'Pum3',
 'Chka',
 'Slc3a2',
 'Atad1',
 'Pprc1',
 'Mms19',
 'Zfp91',
 'Abhd17b',
 'Kif20b',
 'Carnmt1',
 'Zfand5',
 'Slf2',
 'Ndufb8',
 'Gbf1',
 'Gsto1',
 'Prpf19',
 'Cemip2',
 'Ldb1',
 'Pcgf5',
 'Dpp3',
 'Tmem258',
 'Npm3',
 'Lrp5',
 'Nxf1',
 'Btrc',
 'Ide',
 'Sgms1',
 'Vps13a',
 'Hnrnpul2',
 'Ubxn1',
 'Rcl1',
 'Shoc2',
 'Minpp1',
 'Cpsf7',
 'Kank1',
 'Ahnak',
 'Sh3pxd2a',
 'Trmt112',
 'Rrp12',
 'Psat1',
 'Chuk',
 'Smndc1',
 'Slk',
 'Znrd2',
 'Pdzd8',
 'Armh3',
 'Ehd1',
 'Kmt5b',
 'Mta2',
 'Grk2',
 'Ric1',
 'Pacs1',
 'Sart1']

Now that we have the distance between each peak and each gene TSS, we will calculate an exponential drop-off score.

In [26]:
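# Scale the TSS distance using an exponential drop-off function.
# The formula above (from the LINGER cis-regulatory potential calculation) is e^(-dist / 25000).
# NOTE: the denominator used below is 250000, so the decay here is 10x slower than that formula.
# https://github.com/Durenlab/LINGER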
genes_near_peaks = genes_near_peaks.copy()
genes_near_peaks["TSS_dist_score"] = np.exp(-genes_near_peaks["TSS_dist"] / 250000)
genes_near_peaks.head()

Unnamed: 0,peak_chr,peak_start,peak_end,peak_id,gene_chr,gene_start,gene_end,target_id,TSS_dist,TSS_dist_score
400,chr1,4496754,4497354,chr1:4496754-4497354,chr1,4497354,4497354,Sox17,0,1.0
87149,chr1,74542289,74542889,chr1:74542289-74542889,chr1,74542888,74542888,Plcd4,1,0.999996
333683,chr1,190169693,190170293,chr1:190169693-190170293,chr1,190170295,190170295,Prox1os,2,0.999992
50111,chr1,55363149,55363749,chr1:55363149-55363749,chr1,55363753,55363753,Boll,4,0.999984
110386,chr1,84839235,84839835,chr1:84839235-84839835,chr1,84839840,84839840,Fbxo36,5,0.99998


### TF-RE Binding Potential

Homer will be used to estimate the ability of a TF to bind to each peak. TF-RE edges where the TF is not predicted to bind will have a value of 0. We will map the TF-RE binding potential to the genomic ranges by taking the average TF binding potential for peaks within a genomic range. The TF-RE binding potential matrix will be multiplied by a log1p-normalized and min-max-scaled vector of TF expression.
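
The log1p normalization and min-max scaling mentioned above can be sketched as follows. This is an illustration under stated assumptions: `log1p_minmax` is a hypothetical helper, the expression values are toy numbers, and whether the scaling is applied per gene or per cell is not specified here, so the sketch simply scales one vector.

```python
import numpy as np

def log1p_minmax(values):
    """log1p-transform a vector, then rescale it to [0, 1].

    Returns zeros if every value is identical, to avoid dividing by zero.
    """
    logged = np.log1p(np.asarray(values, dtype=float))
    span = logged.max() - logged.min()
    if span == 0:
        return np.zeros_like(logged)
    return (logged - logged.min()) / span

tf_expression = np.array([0.0, 3.0, 10.0, 250.0])  # toy expression values
scaled = log1p_minmax(tf_expression)
print(scaled)
```

The log transform compresses large counts before rescaling, so a single highly expressed TF does not push every other value toward zero.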

Next, we need to format the peaks to use the Homer peak file format to find TFs matching to peaks.

#### Building the Homer peaks file

In [35]:
homer_peaks = genes_near_peaks[["peak_id", "peak_chr", "peak_start", "peak_end"]]
homer_peaks = homer_peaks.rename(columns={
    "peak_id":"PeakID", 
    "peak_chr":"chr",
    "peak_start":"start",
    "peak_end":"end"
    })
homer_peaks["strand"] = "."
homer_peaks["start"] = homer_peaks["start"].astype(int)
homer_peaks["end"] = homer_peaks["end"].astype(int)
homer_peaks = homer_peaks.drop_duplicates(subset="PeakID")

os.makedirs(os.path.join(output_dir, "tmp"), exist_ok=True)
homer_peak_path = os.path.join(output_dir, "tmp/homer_peaks.txt")
homer_peaks.to_csv(homer_peak_path, sep="\t", header=False, index=False)
homer_peaks.head()

Unnamed: 0,PeakID,chr,start,end,strand
400,chr1:4496754-4497354,chr1,4496754,4497354,.
87149,chr1:74542289-74542889,chr1,74542289,74542889,.
333683,chr1:190169693-190170293,chr1,190169693,190170293,.
50111,chr1:55363149-55363749,chr1,55363149,55363749,.
110386,chr1:84839235-84839835,chr1,84839235,84839835,.


In [36]:
print(len(homer_peaks))
print(len(homer_peaks.drop_duplicates(subset="PeakID")))

12940
12940


Next, we need to run Homer on these peaks. I created a `run_homer.sh` script to handle running `findMotifsGenome.pl`, `annotatePeaks.pl`, and the pipeline script `homer_tf_peak_motifs.py`. `homer_tf_peak_motifs.py` calculates Homer TF-to-peak scores by counting the number of motifs found in each peak for a given TF in the output file from `annotatePeaks.pl`.

#### Running Homer

In [38]:
#!sbatch /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/dev/testing_scripts/run_homer.sh

Submitted batch job 3389915


Once Homer has finished running, we will load the results. We are interested in the number of binding sites for each TF in each peak.

In [50]:
homer_results = pd.read_parquet(os.path.join(output_dir, "homer_tf_to_peak.parquet"), engine="pyarrow")
homer_results = homer_results.reset_index(drop=True)
homer_results["source_id"] = homer_results["source_id"].str.capitalize()
homer_results.head()

Unnamed: 0,peak_id,source_id,tf_motifs_in_peak,homer_binding_score
0,chr1:22805812-22806412,Amyb,1.0,0.000148
1,chr1:36251342-36251942,Amyb,1.0,0.000148
2,chr1:36021079-36021679,Amyb,1.0,0.000148
3,chr1:119082498-119083098,Amyb,2.0,0.000295
4,chr1:178684866-178685466,Amyb,1.0,0.000148


Let's review what we have so far:

- scRNA-seq gene expression of the TFs and TGs
- scATAC-seq chromatin accessibility of the peaks
- TSS position of each gene
- A distance score between each peak and TG within 1 Mb of one another
- A TF-peak binding score from Homer
- Genomic bins of 1 kb

In [None]:
print(f"\nmesc_rna_data\n{mesc_rna_data.iloc[0:5, 0:1].head()}")
print(f"\nmesc_atac_data_chr1\n{mesc_atac_data_chr1.iloc[0:5, :1].head()}")
print(f"\ngene_tss_df\n{gene_tss_df.head()}")
print(f"\ngenes_near_peaks\n{genes_near_peaks.head()}")
print(f"\nhomer_results\n{homer_results.head()}")
print(f"\nmm10_chr1_windows\n{mm10_chr1_windows.head()}")

We only want genes that are in either the potential TG list or the unique TF list, and cells that are present in both the RNA and ATAC datasets.

In [115]:
atac_cell_barcodes = mesc_atac_data_chr1.columns.to_list()
rna_cell_barcodes = mesc_rna_data.columns.to_list()
atac_in_rna_shared_barcodes = [i for i in atac_cell_barcodes if i in rna_cell_barcodes]

# Make sure that the cell names are in the same order and in both datasets
shared_barcodes = sorted(set(atac_in_rna_shared_barcodes))

mesc_atac_data_chr1_shared = mesc_atac_data_chr1[shared_barcodes]
mesc_rna_data_shared = mesc_rna_data[shared_barcodes]
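
An equivalent, set-based way to find the shared barcodes is `Index.intersection`, which avoids scanning a list once per barcode. The sketch below uses toy DataFrames with made-up cell names; none of the values come from the real dataset.

```python
import pandas as pd

# Toy matrices: rows are features, columns are cell barcodes (hypothetical names)
rna = pd.DataFrame([[1.0, 0.0, 2.0]], index=["GeneA"], columns=["cell3", "cell1", "cell2"])
atac = pd.DataFrame([[0.0, 1.0, 1.0]], index=["peak1"], columns=["cell2", "cell4", "cell1"])

# Barcodes present in both datasets, sorted so both tables share one column order
shared = sorted(atac.columns.intersection(rna.columns))
rna_shared = rna[shared]
atac_shared = atac[shared]

print(shared)
print(list(rna_shared.columns) == list(atac_shared.columns))
```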

In [116]:
rna_first_cell = mesc_rna_data_shared.iloc[:, 0]
atac_first_cell = mesc_atac_data_chr1_shared.loc[:, rna_first_cell.name] # Makes sure the barcodes match
print(rna_first_cell.name, atac_first_cell.name)

E7.5_rep1#AAACCGGCAGAAATGC-1 E7.5_rep1#AAACCGGCAGAAATGC-1


In [117]:
potential_tgs = genes_near_peaks["target_id"].unique()
print(f"Number of potential TGs: {len(potential_tgs)}")

unique_tfs = homer_results["source_id"].unique()
print(f"Number of unique TFs: {len(unique_tfs)}")

unique_peaks = homer_results["peak_id"].unique()
print(f"Number of unique peaks: {len(unique_peaks)}")

Number of potential TGs: 1425
Number of unique TFs: 298
Number of unique peaks: 12940


The scores that we are interested in are the:

- `TSS_dist_score` between each peak and each potential TG from `genes_near_peaks`
- `tf_motifs_in_peak` between each TF and each peak from `homer_results`
- RNA expression values for TFs and potential TGs from `rna_first_cell`
- ATAC accessibility values for peaks from `atac_first_cell`

We want:
- Matrix 1: (RE accessibility * RE distance score) mapped to the genomic windows
- Matrix 2: (RE accessibility * RE distance score) x potential TG
- Matrix 3: (TF expression * Homer binding score) x Matrix 2


In [135]:
print("Peak to Potential TG Distance Score")
print(genes_near_peaks[["peak_id", "target_id", "TSS_dist_score"]].head())

print("\nHomer TF to Peak Binding Motifs")
print(homer_results[["source_id", "peak_id", "tf_motifs_in_peak"]].head())

print("\nscRNA-seq Gene Expression")
print(rna_first_cell.head())

print("\nscATAC-seq Peak Accessibility")
print(atac_first_cell.head())

Peak to Potential TG Distance Score
                         peak_id target_id  TSS_dist_score
400         chr1:4496754-4497354     Sox17        1.000000
87149     chr1:74542289-74542889     Plcd4        0.999996
333683  chr1:190169693-190170293   Prox1os        0.999992
50111     chr1:55363149-55363749      Boll        0.999984
110386    chr1:84839235-84839835    Fbxo36        0.999980

Homer TF to Peak Binding Motifs
  source_id                   peak_id  tf_motifs_in_peak
0      Amyb    chr1:22805812-22806412                1.0
1      Amyb    chr1:36251342-36251942                1.0
2      Amyb    chr1:36021079-36021679                1.0
3      Amyb  chr1:119082498-119083098                2.0
4      Amyb  chr1:178684866-178685466                1.0

scRNA-seq Gene Expression
gene_id
Xkr4      0.000000
Sox17     0.000000
Mrpl15    0.000000
Lypla1    0.000000
Tcea1     5.930213
Name: E7.5_rep1#AAACCGGCAGAAATGC-1, dtype: float64

scATAC-seq Peak Accessibility
peak_id
chr1:3035602-30

### Combining TF-RE binding potential with RE regulatory potential values

The peak x TG scores (peak accessibility * peak distance score) and the TF x peak scores (TF expression * TF-peak binding potential) will be joined on their shared peaks and multiplied to get the final TF x peak x TG scores.


#### TF-peak Binding Potential

We calculate the TF-peak binding potential as the Homer TF-to-peak binding score multiplied by the expression of each TF.

In [None]:
tf_peak_binding_potential = pd.merge(homer_results, rna_first_cell, left_on="source_id", right_index=True, how="inner")
tf_peak_binding_potential["tf_peak_binding_potential"] = tf_peak_binding_potential["homer_binding_score"] * tf_peak_binding_potential.iloc[:,-1]
tf_peak_binding_potential = tf_peak_binding_potential[["source_id", "peak_id", "tf_peak_binding_potential"]]
tf_peak_binding_potential.head()

Unnamed: 0,source_id,peak_id,tf_peak_binding_potential
7792,Atf4,chr1:33026581-33027181,0.0
7793,Atf4,chr1:107878563-107879163,0.0
7794,Atf4,chr1:37698121-37698721,0.0
7795,Atf4,chr1:23263694-23264294,0.0
7796,Atf4,chr1:72867374-72867974,0.0


#### Peak-TG Regulatory Potential

We calculate the peak-TG regulatory potential as the peak accessibility multiplied by the distance score between the peak and the potential TG TSS.

In [120]:
peak_tg_regulatory_potential = pd.merge(genes_near_peaks, atac_first_cell, left_on="peak_id", right_index=True, how="inner")
peak_tg_regulatory_potential["peak_tg_regulatory_potential"] = peak_tg_regulatory_potential["TSS_dist_score"] * peak_tg_regulatory_potential.iloc[:, -1]
peak_tg_regulatory_potential = peak_tg_regulatory_potential[["peak_id", "target_id", "peak_tg_regulatory_potential"]]
peak_tg_regulatory_potential.head()

Unnamed: 0,peak_id,target_id,peak_tg_regulatory_potential
400,chr1:4496754-4497354,Sox17,5.171927
87149,chr1:74542289-74542889,Plcd4,0.0
333683,chr1:190169693-190170293,Prox1os,0.0
50111,chr1:55363149-55363749,Boll,0.0
110386,chr1:84839235-84839835,Fbxo36,0.0


#### TF-Peak-TG Regulatory Potential

Finally, we get the TF-Peak-TG regulatory potential by multiplying the TF-peak binding potential scores with the peak-TG regulatory potential scores for each shared peak.

In [121]:
tf_peak_tg_regulatory_potential = pd.merge(tf_peak_binding_potential, peak_tg_regulatory_potential, on="peak_id", how="outer")
tf_peak_tg_regulatory_potential["tf_peak_tg_score"] = tf_peak_tg_regulatory_potential["tf_peak_binding_potential"] * tf_peak_tg_regulatory_potential["peak_tg_regulatory_potential"]
tf_peak_tg_regulatory_potential = tf_peak_tg_regulatory_potential[["source_id", "peak_id", "target_id", "tf_peak_tg_score"]]
print(tf_peak_tg_regulatory_potential.head())

  source_id                 peak_id target_id  tf_peak_tg_score
0     Hnf4a  chr1:10007124-10007724   Ppp1r42               0.0
1     Hnf4a  chr1:10007124-10007724     Cops5               0.0
2     Hnf4a  chr1:10007124-10007724     Cspp1               0.0
3     Hnf4a  chr1:10007124-10007724     Cspp1               0.0
4     Hnf4a  chr1:10007124-10007724     Tcf24               0.0


### Aggregate the TF-peak-TG scores into genomic coordinate windows

We will now aggregate the peaks based on the 1 kb windows that we created earlier. The peak scores within a window are summed to get the final score for that window. By aggregating the peaks into windows of a fixed size, we ensure that the transformer architecture will work with any mm10 dataset.

In [122]:
# Parse peak_id into genomic coords (chrom, start, end)
coords = tf_peak_tg_regulatory_potential["peak_id"].str.extract(
    r"(?P<chrom>[^:]+):(?P<start>\d+)-(?P<end>\d+)"
).astype({"start":"int64","end":"int64"})

df = pd.concat([tf_peak_tg_regulatory_potential, coords], axis=1)

df = df[df["chrom"] == "chr1"].copy()

window_size = int((mm10_chr1_windows["end"] - mm10_chr1_windows["start"]).mode().iloc[0])

# Build a quick lookup of window_id strings from window indices
# window index k -> [start=k*w, end=(k+1)*w)
win_lut = {}
for _, row in mm10_chr1_windows.iterrows():
    k = row["start"] // window_size
    win_lut[k] = f'{row["chrom"]}:{row["start"]}-{row["end"]}'

# --- Assign each unique peak to the window with maximal overlap (random ties) ---
rng = np.random.default_rng(0)  # set a seed for reproducibility; change/remove if you want different random choices

peaks_unique = (
    df.loc[:, ["peak_id", "chrom", "start", "end"]]
      .drop_duplicates(subset=["peak_id"])
      .reset_index(drop=True)
)

def assign_best_window(start, end, w):
    # windows indices spanned by the peak (inclusive)
    i0 = start // w
    i1 = (end - 1) // w  # subtract 1 so exact boundary end==k*w goes to k-1 window
    if i1 < i0:
        i1 = i0
    # compute overlaps with all spanned windows
    overlaps = []
    for k in range(i0, i1 + 1):
        bin_start = k * w
        bin_end = bin_start + w
        ov = max(0, min(end, bin_end) - max(start, bin_start))
        overlaps.append((k, ov))
    # choose the k with max overlap; break ties randomly
    ov_vals = [ov for _, ov in overlaps]
    max_ov = max(ov_vals)
    candidates = [k for (k, ov) in overlaps if ov == max_ov]
    if len(candidates) == 1:
        return candidates[0]
    else:
        return rng.choice(candidates)

peak_to_window_idx = peaks_unique.apply(
    lambda r: assign_best_window(r["start"], r["end"], window_size), axis=1
)
peaks_unique["window_idx"] = peak_to_window_idx
peaks_unique["window_id"] = peaks_unique["window_idx"].map(win_lut)

# Map window assignment back to the full TF–peak–gene table and aggregate
df = df.merge(
    peaks_unique.loc[:, ["peak_id", "window_id"]],
    on="peak_id",
    how="left"
)

# Aggregate scores per TF × window × gene
binned_scores = (
    df.groupby(["source_id", "window_id", "target_id"], observed=True)["tf_peak_tg_score"]
      .sum()
      .reset_index()
).rename(columns={"tf_peak_tg_score":"tf_window_tg_score"})

print(binned_scores.head())


  source_id               window_id      target_id  tf_window_tg_score
0      Atf1  chr1:10009000-10010000  1700034P13Rik                 0.0
1      Atf1  chr1:10009000-10010000  2610203C22Rik                 0.0
2      Atf1  chr1:10009000-10010000         Adhfe1                 0.0
3      Atf1  chr1:10009000-10010000        Arfgef1                 0.0
4      Atf1  chr1:10009000-10010000          Cops5                 0.0
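
As a worked example of the maximal-overlap rule, the sketch below computes the per-window overlap for the first chr1 peak in the ATAC table (`chr1:3035602-3036202`) with 1 kb windows. `window_overlaps` is a hypothetical helper; unlike `assign_best_window`, it breaks ties by taking the first window rather than choosing at random.

```python
def window_overlaps(start, end, w):
    """Return {window_index: overlap_in_bp} for a half-open peak [start, end)."""
    first, last = start // w, (end - 1) // w
    return {
        k: min(end, (k + 1) * w) - max(start, k * w)
        for k in range(first, last + 1)
    }

# The 600 bp peak spans the boundary at 3,036,000
overlaps = window_overlaps(3035602, 3036202, 1000)
best = max(overlaps, key=overlaps.get)  # first window with the largest overlap
print(overlaps)
print(f"chr1:{best * 1000}-{(best + 1) * 1000}")
```

The peak overlaps window 3035 by 398 bp and window 3036 by 202 bp, so it is assigned to `chr1:3035000-3036000`.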


Next, we need to pivot the long dataframe into a 3D TF x window x TG NumPy array.

In [123]:
# Get unique IDs
tfs = binned_scores["source_id"].unique()
windows = binned_scores["window_id"].unique()
genes = binned_scores["target_id"].unique()

# Create index maps
tf_idx = {tf: i for i, tf in enumerate(tfs)}
window_idx = {p: i for i, p in enumerate(windows)}
gene_idx = {g: i for i, g in enumerate(genes)}

# Initialize 3D matrix
tensor_np = np.zeros((len(tfs), len(windows), len(genes)), dtype=float)

# Fill values
for _, row in binned_scores.iterrows():
    i = tf_idx[row["source_id"]]
    j = window_idx[row["window_id"]]
    k = gene_idx[row["target_id"]]
    tensor_np[i, j, k] = row["tf_window_tg_score"]

print(tensor_np.shape)  # (n_TFs, n_windows, n_genes)

(106, 11675, 1415)
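
Looping over rows with `iterrows` is slow for a table of this size. A vectorized alternative, sketched below on toy data, uses `pd.factorize` to turn each ID column into integer codes and then fills the array in one indexed assignment. This assumes each (TF, window, gene) combination appears at most once, which holds after the `groupby` above.

```python
import numpy as np
import pandas as pd

# Toy long-format table with the same columns as binned_scores (made-up values)
toy_scores = pd.DataFrame({
    "source_id": ["TfA", "TfA", "TfB"],
    "window_id": ["w1", "w2", "w1"],
    "target_id": ["Gene1", "Gene1", "Gene2"],
    "tf_window_tg_score": [1.5, 0.5, 2.0],
})

# factorize returns integer codes plus the unique labels in order of appearance
tf_codes, tf_labels = pd.factorize(toy_scores["source_id"])
win_codes, win_labels = pd.factorize(toy_scores["window_id"])
gene_codes, gene_labels = pd.factorize(toy_scores["target_id"])

toy_tensor = np.zeros((len(tf_labels), len(win_labels), len(gene_labels)))
# One indexed assignment fills every (TF, window, gene) entry at once
toy_tensor[tf_codes, win_codes, gene_codes] = toy_scores["tf_window_tg_score"].to_numpy()
print(toy_tensor.shape)
```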


Now that we have the input tensor, we will save it to disk in case the kernel crashes.

In [124]:
np.savez_compressed(
    os.path.join(output_dir, "tf_window_gene_tensor.npz"),
    tensor=tensor_np,
    tfs=tfs,
    windows=windows,
    genes=genes
)
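
Reloading the archive later might look like the sketch below, which round-trips toy arrays through a temporary file. One assumption to check: if the label arrays are object dtype (which `Series.unique()` commonly returns for Python strings), NumPy pickles them on save, and `np.load` then needs `allow_pickle=True` to read them back.

```python
import os
import tempfile
import numpy as np

# Toy arrays standing in for tensor_np and the TF labels
toy_tensor = np.arange(8, dtype=float).reshape(2, 2, 2)
toy_tfs = np.array(["TfA", "TfB"], dtype=object)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_tensor.npz")
    np.savez_compressed(path, tensor=toy_tensor, tfs=toy_tfs)

    # Object-dtype arrays are stored pickled, so loading requires allow_pickle=True
    with np.load(path, allow_pickle=True) as archive:
        loaded_tensor = archive["tensor"]
        loaded_tfs = archive["tfs"]

print(loaded_tensor.shape, list(loaded_tfs))
```

Only enable `allow_pickle=True` for files you created yourself, since unpickling can execute arbitrary code.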

In [None]:
print(f"Number of TFs: {tensor_np.shape[0]}")
print(f"Number of Windows: {tensor_np.shape[1]}")
print(f"Number of TGs: {tensor_np.shape[2]}")

Number of TFs: 106
Number of Windows: 11675
Number of TGs: 1415


To recap, the scores in this array represent the TF-peak regulatory potentials for each TG, based on the Homer TF-peak binding scores, the peak-TG regulatory potential scores, the peak accessibility values, and the TF RNA expression values.

---

## Transformer Architecture

We will use the genomic windows as the sequence and let each gene query the windows for evidence to predict its expression. The tokens will be the 11,675 chr1 windows in the tensor built above. The features for each token will be set using learned linear projections of the TF and TG axes.

### Data Normalization and Range Clamping

We convert the NumPy array into a PyTorch tensor, standardize the data using Z-score normalization, and clamp the values to the range [-5, 5].

In [125]:
X = torch.from_numpy(tensor_np).float()     # 106 TFs x 11,675 windows x 1,415 TGs

# Z-score each (TF, window) row across the TGs
X = (X - X.mean(dim=2, keepdim=True)) / (X.std(dim=2, keepdim=True) + 1e-6)
X = torch.clamp(X, -5, 5)
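
The effect of the normalization and clamping can be checked on a small random tensor. This is a toy sketch with made-up shapes; it also includes an all-zero row, since many (TF, window) rows in sparse data are entirely zero and the `1e-6` term is what keeps those rows from becoming NaN.

```python
import torch

torch.manual_seed(0)
toy = torch.randn(2, 3, 50) * 10 + 4   # [TF, window, TG] with arbitrary scale and offset
toy[0, 0] = 0.0                        # an all-zero row, as is common in sparse data

mean = toy.mean(dim=2, keepdim=True)
std = toy.std(dim=2, keepdim=True)
toy_norm = torch.clamp((toy - mean) / (std + 1e-6), -5, 5)

print(toy_norm.mean(dim=2))          # per-row means after standardization
print(toy_norm[0, 0].abs().max())    # the all-zero row stays at zero
print(torch.isfinite(toy_norm).all())
```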

### Linear projection of the TF and TG features

Running a transformer model on a 106 x 11,675 x 1,415 tensor would be too computationally heavy. Instead, we will use trainable linear projections to reduce the dimensionality of the TF and TG features to a 256-dimensional embedding per window.

In [126]:
# We will use 256 features (d_model) per window
TF, W, TG = X.shape
d_model = 256

# Project the TG dimension down to 64 using a linear projection
tg_channels = 64 
proj_tg = nn.Linear(TG, tg_channels, bias=False)
Xg = proj_tg(X)             # TF, W, 64

# Project the TF dimension down to 32 using a linear projection
tf_channels = 32
proj_tf = nn.Linear(TF, tf_channels, bias=False)

# Need to permute Xg to get the TFs as the last dimension before projecting
Xg = Xg.permute(1, 2, 0)    # W, 64, TF
Xg = proj_tf(Xg)            # W, 64, 32

# Next, we flatten the 64 tg_channels x 32 tf_channels
window_features = Xg.reshape(W, tg_channels * tf_channels)

# Now we project the window features to the shape of d_model
proj_window = nn.Linear(tg_channels * tf_channels, d_model)
tokens = proj_window(window_features)   # W, 256


In [133]:
tokens.shape

torch.Size([1, 11675, 256])

### Downsampling the Windows

A sequence of 11,675 windows is still too long to run, because the compute and memory required for self-attention scale quadratically with the sequence length. We will pool the tokens using a kernel size of 8 and a stride of 8. This averages each group of 8 consecutive windows into a single token, shortening the sequence by a factor of 8.

In [127]:
# Downsample by average pooling across windows
pool = nn.AvgPool1d(kernel_size=8, stride=8)  # along sequence length, bins and pools the data
tokens = tokens.unsqueeze(0)
tokens_ds = pool(tokens.transpose(1,2)).transpose(1,2)   # [1, W', d_model]
W_ds = tokens_ds.size(1)

In [134]:
tokens_ds.shape

torch.Size([1, 1459, 256])
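
The effect of the pooling on the sequence length and on the attention cost can be checked with a few lines of arithmetic. This is a sketch: the helper names are hypothetical, and the length formula is the one for `nn.AvgPool1d` with no padding and `ceil_mode=False`.

```python
def pooled_length(seq_len: int, kernel_size: int, stride: int) -> int:
    """Output length of nn.AvgPool1d with no padding and ceil_mode=False."""
    return (seq_len - kernel_size) // stride + 1


def attention_score_entries(seq_len: int) -> int:
    """Entries in one head's seq_len x seq_len self-attention score matrix."""
    return seq_len * seq_len


W = 11675   # sequence length shown in the printed tokens.shape
W_ds = pooled_length(W, kernel_size=8, stride=8)

print(W_ds)   # 1459, matching the printed tokens_ds.shape
print(attention_score_entries(W) / attention_score_entries(W_ds))
```

Because the score matrix grows with the square of the sequence length, an 8x shorter sequence shrinks each head's score matrix by roughly 64x.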

### Building the transformer encoder layer

Next, we set up a transformer encoder with 4 layers, each with 8 attention heads and a feedforward dimension of 512.

In [128]:
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=8, dim_feedforward=512, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

H = encoder(tokens_ds)   # [1, W', d_model]

### Positional Encoding (Not used for now)

Next, we use a positional encoder to encode the positions of the genomic windows. This will allow the model to learn context about whether a gene should be expressed using information about the regulatory potential landscape around it.

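We will use a local relative position bias (RPB), which adds a learnable per-head bias to the attention scores based on the offset between two windows. Each offset within ±10 kb gets its own bias, and larger offsets share the bias of the outermost bucket. This helps the attention mechanism learn how strongly to favor nearby windows over distant ones.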

In [None]:
class RelativePositionBias(nn.Module):
    """
    Learned relative position bias for self-attention over windows.
    - max_kb: e.g., 10 (±10 kb neighborhood)
    - window_size_bp: e.g., 1000 for 1kb windows
    - n_heads: attention heads to broadcast the bias over
    """
    def __init__(self, n_heads: int, window_size_bp: int, max_kb: int):
        super().__init__()
        self.n_heads = n_heads
        self.window_size_bp = window_size_bp
        self.max_offset = max(1, (max_kb * 1000) // window_size_bp)  # in window units
        n_buckets = 2 * self.max_offset + 1  # offsets from -max..+max
        self.bias = nn.Parameter(torch.zeros(n_heads, n_buckets))     # [H, 2*max+1]
        nn.init.zeros_(self.bias)

    def forward(self, seq_len: int, device=None):
        """
        Returns an additive bias tensor for attention scores:
        shape [H, L, L], to be added to the (H, L, L) logits.
        """
        device = device or self.bias.device
        # offsets[i,j] = j - i  in window units
        idx = torch.arange(seq_len, device=device)
        off = idx[None, :] - idx[:, None]   # [L, L] in windows

        # clip to [-max_offset, +max_offset]
        off = off.clamp(-self.max_offset, self.max_offset)
        # map offset -> bucket index [0..2*max_offset], center at 0 offset
        buckets = (off + self.max_offset).long()  # [L, L]

        # gather per-head bias
        # bias: [H, B], buckets: [L, L] -> out [H, L, L]
        out = self.bias[:, buckets]  # advanced indexing
        return out  # [H, L, L]
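
The notebook does not wire this bias into the model yet. One possible way to do so, sketched below under toy assumptions, is to pass the `[H, L, L]` tensor as a floating-point `attn_mask` to `nn.MultiheadAttention`, which adds a float mask to the attention scores. The sizes and the hand-filled bias table are made up so that the bucket lookup is easy to read.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_heads, seq_len, max_offset, d_model_toy = 2, 6, 2, 8   # toy sizes (assumptions)

# Hand-filled bias table [H, 2*max_offset+1] so each bucket is recognizable
bias = torch.arange(n_heads * (2 * max_offset + 1), dtype=torch.float32)
bias = bias.reshape(n_heads, 2 * max_offset + 1)

# Same lookup as RelativePositionBias.forward: offset j - i, clipped, shifted to [0, 2*max]
idx = torch.arange(seq_len)
off = (idx[None, :] - idx[:, None]).clamp(-max_offset, max_offset)
buckets = off + max_offset
rpb = bias[:, buckets]            # [H, L, L]

print(rpb[0])                     # constant along each diagonal inside the ±max_offset band

# With batch size 1, a 3D float attn_mask of shape [batch * H, L, L] is added to the scores
mha = nn.MultiheadAttention(embed_dim=d_model_toy, num_heads=n_heads, batch_first=True)
x = torch.randn(1, seq_len, d_model_toy)
out, _ = mha(x, x, x, attn_mask=rpb)
print(out.shape)
```

For batches larger than 1, the bias would need to be repeated to shape `[batch * H, L, L]` before being passed as the mask.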

### Cross-attention from the gene queries to the window tokens

Each TG has a learned Query vector, and the window tokens each have Key, Value vectors. This allows the model to learn to predict the TG expression using context about the values of the windows around it.

#### Set up the multi-headed attention

In [129]:
n_genes = len(genes)  # 1425
gene_embed = nn.Embedding(n_genes, d_model)

cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=8, batch_first=True)
readout = nn.Sequential(
    nn.LayerNorm(d_model),
    nn.Linear(d_model, 1)
)

# Build gene queries (index 0..n_genes-1)
gene_ids = torch.arange(n_genes)
GQ = gene_embed(gene_ids).unsqueeze(0)  # [1, n_genes, d_model]

# Cross-attention: (Q=genes, K/V=windows)
Z, _ = cross_attn(query=GQ, key=H, value=H)  # [1, n_genes, d_model]

pred_expr = readout(Z).squeeze(-1)           # [1, n_genes]
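
A side benefit of this design is that the second return value of `nn.MultiheadAttention` holds the attention weights, which show which windows each gene query draws on. The sketch below uses toy sizes and random stand-ins for the encoder output, so the specific numbers carry no biological meaning.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_genes_toy, n_windows_toy, d_model_toy = 5, 12, 16   # toy sizes (assumptions)

gene_embed = nn.Embedding(n_genes_toy, d_model_toy)
cross_attn = nn.MultiheadAttention(embed_dim=d_model_toy, num_heads=4, batch_first=True)

H_toy = torch.randn(1, n_windows_toy, d_model_toy)          # stand-in for the encoder output
GQ = gene_embed(torch.arange(n_genes_toy)).unsqueeze(0)     # [1, n_genes, d_model]

# attn is averaged over heads by default: [1, n_genes, n_windows]
Z, attn = cross_attn(query=GQ, key=H_toy, value=H_toy)

print(Z.shape)
print(attn.shape)
print(attn[0].sum(dim=-1))          # each gene's weights over the windows sum to 1

# Index of the window each gene attends to most strongly
top_window = attn[0].argmax(dim=-1)
print(top_window)
```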

#### Register Parameters

In [130]:
params = (
    list(proj_tg.parameters()) +
    list(proj_tf.parameters()) +
    list(proj_window.parameters()) +
    list(encoder.parameters()) +
    list(gene_embed.parameters()) +
    list(cross_attn.parameters()) +
    list(readout.parameters())
    # + list(rpb.parameters())  # if you use the custom RPB blocks
)
opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-4)

#### Move everything to the same device (CUDA if available)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X = X.to(device)
proj_tg, proj_tf, proj_window, encoder = proj_tg.to(device), proj_tf.to(device), proj_window.to(device), encoder.to(device)
gene_embed, cross_attn, readout = gene_embed.to(device), cross_attn.to(device), readout.to(device)
tokens_ds = tokens_ds.to(device)
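
An alternative to listing every module by hand is to keep them in one `nn.ModuleDict`, so that a single `.to(device)` call and a single `.parameters()` call cover everything. This is only a sketch with toy layer sizes; the dictionary keys mirror the variable names used in this notebook, but the grouping itself is an assumption and not part of the original code.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions); the notebook uses TF=106, TG=1425, d_model=256
model = nn.ModuleDict({
    "proj_tg": nn.Linear(20, 4, bias=False),
    "proj_tf": nn.Linear(5, 2, bias=False),
    "proj_window": nn.Linear(4 * 2, 16),
    "gene_embed": nn.Embedding(20, 16),
    "readout": nn.Sequential(nn.LayerNorm(16), nn.Linear(16, 1)),
})

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)   # moves every registered submodule in one call

# model.parameters() walks all submodules, so none is left out of the optimizer
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

n_params = sum(p.numel() for p in model.parameters())
print(n_params, next(model.parameters()).device)
```

Input and target tensors still have to be moved separately, since they are not registered on any module.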

#### Get the true TG expression values to calculate error

The TG expression vector needs to be in the same order and have the same length as the TG predictions (the order of `genes`).

In [131]:
dup = pd.Index(genes).duplicated()
assert not dup.any(), f"Duplicate gene IDs in prediction axis at: {np.where(dup)[0][:10]}"

# Align counts to prediction order from the gene to index mapping (same length and order as genes)
true_counts = rna_first_cell.reindex(genes)

# build mask for missing genes (not present in RNA)
mask = ~true_counts.isna().to_numpy()

# Handle missing genes using a masked loss 
y_true_vec = true_counts.to_numpy(dtype=float)        # shape (n_genes,)

y_true = torch.tensor(y_true_vec, dtype=torch.float32).unsqueeze(0)   # [1, n_genes]
mask_t = torch.tensor(mask, dtype=torch.bool).unsqueeze(0)            # [1, n_genes]
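
The alignment step can be illustrated on a tiny made-up example. The gene names and values below are hypothetical; the point is that `reindex` reorders the series to the prediction order, drops genes that are not predicted, and fills genes missing from the RNA data with NaN, which the mask then flags.

```python
import pandas as pd
import torch

genes_toy = ["GeneA", "GeneB", "GeneC"]                           # prediction order (hypothetical)
rna_toy = pd.Series({"GeneC": 3.0, "GeneA": 1.0, "GeneX": 9.0})   # measured expression (made up)

aligned = rna_toy.reindex(genes_toy)      # GeneB becomes NaN, GeneX is dropped
mask = ~aligned.isna().to_numpy()         # True where a measurement exists

y_true_toy = torch.tensor(aligned.to_numpy(dtype=float), dtype=torch.float32).unsqueeze(0)
mask_toy = torch.tensor(mask, dtype=torch.bool).unsqueeze(0)

print(aligned.tolist())
print(mask_toy)
```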

### Loss Calculation

We will calculate the mean squared error (MSE) between the true and predicted TG expression values, restricted to the genes that have a measured expression value (the `True` entries of the mask).

In [132]:
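def masked_mse(pred, y, m):
    # Select the observed genes before subtracting, so NaN targets for
    # missing genes never enter the autograd graph (0 * NaN = NaN in backward)
    diff2 = (pred[m] - y[m])**2
    return diff2.mean()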

# pred_expr: [1, n_genes] from the model
loss = masked_mse(pred_expr, y_true, mask_t)
print(loss)

tensor(7.1506, grad_fn=<MeanBackward0>)
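
One caveat when the targets contain NaN for missing genes: the order of masking and subtracting affects the gradients, even though the loss value is the same. The sketch below uses made-up numbers and hypothetical function names to compare the two orders.

```python
import torch

y = torch.tensor([[1.0, float("nan"), 5.0]])      # NaN marks a gene missing from the RNA data
m = torch.tensor([[True, False, True]])


def mse_subtract_then_mask(pred, y, m):
    # NaN enters the graph through (pred - y) and is only hidden in the forward pass
    return ((pred - y) ** 2)[m].mean()


def mse_mask_then_subtract(pred, y, m):
    # NaN targets are dropped before any arithmetic touches them
    return ((pred[m] - y[m]) ** 2).mean()


pred_a = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
loss_a = mse_subtract_then_mask(pred_a, y, m)
loss_a.backward()

pred_b = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
loss_b = mse_mask_then_subtract(pred_b, y, m)
loss_b.backward()

print(loss_a.item(), loss_b.item())   # identical loss values
print(pred_a.grad)                    # NaN at the masked position
print(pred_b.grad)                    # finite everywhere
```

In the first variant the masked position receives a gradient of `0 * NaN`, which is NaN and would propagate into the shared readout weights during training. Masking first avoids this.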
