# Transformer Model

My idea here is to use the TF expression, RE accessibility, and the distance from each RE to each TG TSS as the input to a transformer model that predicts TG expression. The output of the transformer will be the predicted expression of each TG, which will be compared to the measured expression in the dataset.

In [1]:
!hostnamectl

   Static hostname: psh01com1hcom35
         Icon name: computer-server
           Chassis: server
        Machine ID: 6860da98c8574f44be8f2ea25abdb7fb
           Boot ID: bbdde321ce964aff96f74d2bfdfe077c
  Operating System: Red Hat Enterprise Linux 8.10 (Ootpa)
       CPE OS Name: cpe:/o:redhat:enterprise_linux:8::baseos
            Kernel: Linux 4.18.0-553.22.1.el8_10.x86_64
      Architecture: x86-64


In [2]:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import pybedtools
from grn_inference import utils

torch.manual_seed(1)
np.random.seed(42)

project_dir = "/gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER"
mm10_genome_dir = os.path.join(project_dir, "data/reference_genome/mm10")
mm10_gene_tss_file = os.path.join(project_dir, "data/genome_annotation/mm10/mm10_TSS.bed")
ground_truth_dir = os.path.join(project_dir, "ground_truth_files")
sample_input_dir = os.path.join(project_dir, "input/mESC/filtered_L2_E7.5_rep1")
output_dir = os.path.join(project_dir, "output/transformer_testing_output")

### Splitting the mm10 genome into ranges

Peak locations differ from sample to sample. To let the method work across multiple samples, peaks will be assigned to fixed genomic ranges so that the model learns on a shared coordinate system. A peak that overlaps two genomic ranges is assigned to the range that covers the majority of the peak. If a peak is split evenly between two ranges, it is assigned to one of them at random.

#### Read in the mm10 gene TSS bed file

In [60]:
mm10_fasta_file = os.path.join(mm10_genome_dir, "chr19.fa")
mm10_chrom_sizes_file = os.path.join(mm10_genome_dir, "chrom.sizes")

In [61]:
print("Reading and formatting TSS bed file")
mm10_gene_tss_bed = pybedtools.BedTool(mm10_gene_tss_file)
gene_tss_df = (
    mm10_gene_tss_bed
    .filter(lambda x: x.chrom == "chr19")
    .saveas(os.path.join(mm10_genome_dir, "mm10_ch19_gene_tss.bed"))
    .to_dataframe()
    .sort_values(by="start", ascending=True)
    )
gene_tss_df.head()



Reading and formatting TSS bed file


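|   | chrom | start | end | name | score | strand |
|---|---|---|---|---|---|---|
| 0 | chr19 | 3197703 | 3197703 | 1700030N03Rik | . | - |
| 1 | chr19 | 3283010 | 3283010 | Ighmbp2 | . | - |
| 2 | chr19 | 3283041 | 3283041 | Mrpl21 | . | + |
| 3 | chr19 | 3288919 | 3288919 | Mir6984 | . | + |
| 4 | chr19 | 3323300 | 3323300 | Cpt1a | . | + |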


#### Read in the scATAC-seq dataset

We will also need the scATAC-seq dataset that we will use for training, which we load from the processed scATAC-seq counts Parquet file.

In [62]:
mesc_atac_data = pd.read_parquet(os.path.join(sample_input_dir, "mESC_filtered_L2_E7.5_rep1_ATAC_processed.parquet")).set_index("peak_id")
mesc_atac_peak_loc = mesc_atac_data.index
mesc_atac_peak_loc = utils.format_peaks(mesc_atac_peak_loc)
mesc_atac_peak_loc = mesc_atac_peak_loc[mesc_atac_peak_loc["chromosome"] == "chr19"]
mesc_atac_peak_loc = mesc_atac_peak_loc.rename(columns={"chromosome":"chrom"})
mesc_atac_peak_loc.head()

|   | chrom | start | end | strand | peak_id |
|---|---|---|---|---|---|
| 84361 | chr19 | 3282721 | 3283321 | . | chr19:3282721-3283321 |
| 84362 | chr19 | 3291918 | 3292518 | . | chr19:3291918-3292518 |
| 84363 | chr19 | 3306995 | 3307595 | . | chr19:3306995-3307595 |
| 84364 | chr19 | 3308573 | 3309173 | . | chr19:3308573-3309173 |
| 84365 | chr19 | 3321460 | 3322060 | . | chr19:3321460-3322060 |


We will also restrict the scATAC-seq data to only use chromatin accessibility data for chromosome 19 for now.

In [63]:
mesc_atac_data_chr19 = mesc_atac_data[mesc_atac_data.index.isin(mesc_atac_peak_loc.peak_id)]
mesc_atac_data_chr19.iloc[:5, :3].head()

| peak_id | E7.5_rep1#AAACCGGCAGAAATGC-1 | E7.5_rep1#AAACGGATCATAACTG-1 | E7.5_rep1#AAAGCACCATTAGCGC-1 |
|---|---|---|---|
| chr19:3282721-3283321 | 0.0 | 3.43104 | 0.0 |
| chr19:3291918-3292518 | 0.0 | 0.0 | 0.0 |
| chr19:3306995-3307595 | 0.0 | 0.0 | 0.0 |
| chr19:3308573-3309173 | 0.0 | 0.0 | 0.0 |
| chr19:3321460-3322060 | 0.0 | 0.0 | 0.0 |


#### Read in the scRNA-seq dataset

In addition to the scATAC-seq dataset, we will also need the corresponding gene expression, which we load from the processed scRNA-seq counts Parquet file.

In [64]:
mesc_rna_data = pd.read_parquet(os.path.join(sample_input_dir, "mESC_filtered_L2_E7.5_rep1_RNA_processed.parquet")).set_index("gene_id")
mesc_rna_data.iloc[0:5, 0:3].head()

| gene_id | E7.5_rep1#AAACCGGCAGAAATGC-1 | E7.5_rep1#AAACGGATCATAACTG-1 | E7.5_rep1#AAAGCACCATTAGCGC-1 |
|---|---|---|---|
| Xkr4 | 0.0 | 0.0 | 0.0 |
| Sox17 | 0.0 | 0.0 | 0.0 |
| Mrpl15 | 0.0 | 0.0 | 0.0 |
| Lypla1 | 0.0 | 0.0 | 0.0 |
| Tcea1 | 5.930213 | 0.0 | 0.0 |


### Genomic range distance to TG TSS

The distance between each genomic range and each potential TG TSS within 1 Mb of the range will be converted into a score using the exponential drop-off function:

$$\text{Scaling Factor} = e^{-\frac{\text{peak to TSS Distance}}{25000}}$$

This down-weights the regulatory effect that peaks in a genomic range can exert on a potential TG as the distance to the TG's TSS grows. The values in this matrix will be multiplied by the log1p-normalized, min-max-scaled RE accessibility.
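To make the decay concrete, the following sketch evaluates the formula above at a few distances. `DECAY_BP` and `tss_distance_score` are illustrative names; only the 25,000 bp constant comes from the formula.

```python
import numpy as np

# Decay constant from the formula above: score = exp(-distance / 25000)
DECAY_BP = 25_000

def tss_distance_score(distance_bp):
    """Return exp(-distance / 25 kb) for a scalar or array of peak-to-TSS distances (bp)."""
    return np.exp(-np.asarray(distance_bp, dtype=float) / DECAY_BP)

distances = np.array([0, 1_000, 10_000, 25_000, 100_000, 1_000_000])
scores = tss_distance_score(distances)
for d, s in zip(distances, scores):
    print(f"{int(d):>9,} bp -> {s:.4f}")
```

A peak 25 kb from a TSS keeps about 37% of its weight (e^-1 ≈ 0.368), a peak 100 kb away keeps under 2%, and peaks near the 1 Mb cutoff contribute essentially nothing.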

We will tile the mm10 genome using the mm10 `chrom.sizes` file and a window size of 1 kb. We will use this to create our genomic ranges for mapping peaks to potential TGs.

In [65]:
window_size = 1000
mm10_genome_windows = pybedtools.bedtool.BedTool().window_maker(g=mm10_chrom_sizes_file, w=window_size)
mm10_chr19_windows = (
    mm10_genome_windows
    .filter(lambda x: x.chrom == "chr19")
    .saveas(os.path.join(mm10_genome_dir, f"mm10_chr19_windows_{window_size // 1000}kb.bed"))
    .to_dataframe()
    )
mm10_chr19_windows

|   | chrom | start | end |
|---|---|---|---|
| 0 | chr19 | 0 | 1000 |
| 1 | chr19 | 1000 | 2000 |
| 2 | chr19 | 2000 | 3000 |
| 3 | chr19 | 3000 | 4000 |
| 4 | chr19 | 4000 | 5000 |
| ... | ... | ... | ... |
| 61427 | chr19 | 61427000 | 61428000 |
| 61428 | chr19 | 61428000 | 61429000 |
| 61429 | chr19 | 61429000 | 61430000 |
| 61430 | chr19 | 61430000 | 61431000 |


#### Calculate the distance score between peaks and potential target genes

Now that we have the ATAC peak locations and the gene TSS locations, we can calculate the distance between each peak and each gene TSS, keeping only the peak-TG pairs that are within 1 Mb of one another.

In [66]:
peak_bed = pybedtools.BedTool.from_dataframe(
    mesc_atac_peak_loc[["chrom", "start", "end", "peak_id"]]
    ).saveas(os.path.join(output_dir, "peak_tmp.bed"))

tss_bed = pybedtools.BedTool.from_dataframe(
    gene_tss_df[["chrom", "start", "end", "name"]]
    ).saveas(os.path.join(output_dir, "tss_tmp.bed"))

genes_near_peaks = utils.find_genes_near_peaks(peak_bed, tss_bed)
genes_near_peaks = genes_near_peaks[genes_near_peaks["TSS_dist"] <= 1e6]
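`utils.find_genes_near_peaks` is a project helper whose code is not shown here. As a rough stand-in, the hypothetical pandas function below pairs every peak with every TSS on the same chromosome and keeps the pairs within 1 Mb. It measures distance from the peak midpoint to the TSS, which is an assumption for this sketch; the helper's own convention evidently differs (the `TSS_dist` values in the table below are not midpoint distances).

```python
import pandas as pd

MAX_DIST_BP = 1_000_000  # keep peak-TG pairs within 1 Mb

def peaks_near_tss(peaks: pd.DataFrame, tss: pd.DataFrame, max_dist: int = MAX_DIST_BP) -> pd.DataFrame:
    """Pair peaks with TSSs on the same chromosome and keep pairs within max_dist bp.

    peaks: columns chrom, start, end, peak_id
    tss:   columns chrom, start (TSS position), name
    Distance is |peak midpoint - TSS| (an assumption of this sketch).
    """
    tss = tss.rename(columns={"start": "tss", "name": "target_id"})[["chrom", "tss", "target_id"]]
    # Cross join within each chromosome: fine for one chromosome, quadratic in general
    pairs = peaks[["chrom", "start", "end", "peak_id"]].merge(tss, on="chrom")
    midpoint = (pairs["start"] + pairs["end"]) // 2
    pairs["TSS_dist"] = (midpoint - pairs["tss"]).abs()
    return pairs[pairs["TSS_dist"] <= max_dist].reset_index(drop=True)

peaks = pd.DataFrame({
    "chrom": ["chr19", "chr19"],
    "start": [3282721, 40_000_000],
    "end": [3283321, 40_000_600],
    "peak_id": ["chr19:3282721-3283321", "chr19:40000000-40000600"],
})
tss = pd.DataFrame({
    "chrom": ["chr19", "chr19"],
    "start": [3283010, 3323300],
    "end": [3283010, 3323300],
    "name": ["Ighmbp2", "Cpt1a"],
})
near = peaks_near_tss(peaks, tss)
print(near[["peak_id", "target_id", "TSS_dist"]])
```

Only the first peak survives the filter; the second sits more than 36 Mb from both TSSs.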

Now that we have the distance between each peak and gene TSS, we will calculate an exponential dropoff score.

In [67]:
# Scale the TSS distance using an exponential drop-off function, exp(-dist / 25000),
# the same scaling function used in the LINGER cis-regulatory potential calculation
# https://github.com/Durenlab/LINGER
genes_near_peaks = genes_near_peaks.copy()
genes_near_peaks["TSS_dist_score"] = np.exp(-genes_near_peaks["TSS_dist"] / 25000)
genes_near_peaks.head()

|   | peak_chr | peak_start | peak_end | peak_id | gene_chr | gene_start | gene_end | target_id | TSS_dist | TSS_dist_score |
|---|---|---|---|---|---|---|---|---|---|---|
| 146647 | chr19 | 32387451 | 32388051 | chr19:32387451-32388051 | chr19 | 32388049 | 32388049 | Sgms1 | 2 | 0.99992 |
| 188509 | chr19 | 44756302 | 44756902 | chr19:44756302-44756902 | chr19 | 44756905 | 44756905 | Pax2 | 3 | 0.99988 |
| 176071 | chr19 | 42779374 | 42779974 | chr19:42779374-42779974 | chr19 | 42779978 | 42779978 | Hps1 | 4 | 0.99984 |
| 80777 | chr19 | 6941265 | 6941865 | chr19:6941265-6941865 | chr19 | 6941860 | 6941860 | Bad | 5 | 0.9998 |
| 224759 | chr19 | 53142146 | 53142746 | chr19:53142146-53142746 | chr19 | 53142755 | 53142755 | Add3 | 9 | 0.99964 |


### TF-RE Binding Potential

Homer will be used to score how strongly each TF is predicted to bind each peak. TF-RE edges where the TF is not predicted to bind will have a value of 0. We will map the TF-RE binding potential to the genomic ranges by averaging the TF binding potential over the peaks within each range. The TF-RE binding potential matrix will then be multiplied by a log1p-normalized, min-max-scaled vector of TF expression.
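The scaling and window averaging described above can be sketched on toy data. This is a minimal illustration assuming a TF x peak binding matrix and a hypothetical `peak_to_window` mapping; `log1p_minmax` is an illustrative helper, not part of `grn_inference.utils`. Note that the later cells in this notebook multiply by the processed expression values directly and sum, rather than average, scores within a window.

```python
import numpy as np
import pandas as pd

def log1p_minmax(values: pd.Series) -> pd.Series:
    """log1p-transform, then min-max scale to [0, 1]; a constant vector maps to zeros."""
    logged = np.log1p(values.astype(float))
    span = logged.max() - logged.min()
    if span == 0:
        return logged * 0.0
    return (logged - logged.min()) / span

# Toy TF x peak binding scores (0 = TF not predicted to bind the peak)
binding = pd.DataFrame(
    [[0.2, 0.0, 0.5],
     [0.0, 0.4, 0.1]],
    index=["Klf9", "Glis3"],
    columns=["peak_a", "peak_b", "peak_c"],
)
tf_expr = pd.Series({"Klf9": 3.0, "Glis3": 0.0})   # toy TF expression for one cell

# Weight each TF's row of binding scores by its scaled expression
weighted = binding.mul(log1p_minmax(tf_expr), axis=0)

# Average the peaks that fall in the same genomic window (hypothetical mapping)
peak_to_window = pd.Series({"peak_a": "win_1", "peak_b": "win_1", "peak_c": "win_2"})
tf_window = weighted.T.groupby(peak_to_window).mean().T   # TF x window
print(tf_window)
```

Because Glis3 has the lowest expression, min-max scaling maps it to 0 and all of its window scores vanish, even where it has predicted binding sites.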

Next, we need to format the peaks to use the Homer peak file format to find TFs matching to peaks.

#### Building the Homer peaks file

In [68]:
homer_peaks = genes_near_peaks[["peak_id", "peak_chr", "peak_start", "peak_end"]]
homer_peaks = homer_peaks.rename(columns={
    "peak_id":"PeakID", 
    "peak_chr":"chr",
    "peak_start":"start",
    "peak_end":"end"
    })
homer_peaks["strand"] = "."
homer_peaks["start"] = homer_peaks["start"].astype(int)
homer_peaks["end"] = homer_peaks["end"].astype(int)
homer_peaks = homer_peaks.drop_duplicates(subset="PeakID")

os.makedirs(os.path.join(output_dir, "tmp"), exist_ok=True)
homer_peak_path = os.path.join(output_dir, "tmp/homer_peaks.txt")
homer_peaks.to_csv(homer_peak_path, sep="\t", header=False, index=False)
homer_peaks.head()

|   | PeakID | chr | start | end | strand |
|---|---|---|---|---|---|
| 146647 | chr19:32387451-32388051 | chr19 | 32387451 | 32388051 | . |
| 188509 | chr19:44756302-44756902 | chr19 | 44756302 | 44756902 | . |
| 176071 | chr19:42779374-42779974 | chr19 | 42779374 | 42779974 | . |
| 80777 | chr19:6941265-6941865 | chr19 | 6941265 | 6941865 | . |
| 224759 | chr19:53142146-53142746 | chr19 | 53142146 | 53142746 | . |


In [69]:
print(len(homer_peaks))
print(len(homer_peaks.drop_duplicates(subset="PeakID")))

5585
5585


Next, we need to run Homer on these peaks. I created a `run_homer.sh` script that runs `findMotifsGenome.pl`, `annotatePeaks.pl`, and the pipeline script `homer_tf_peak_motifs.py`. `homer_tf_peak_motifs.py` calculates Homer TF-to-peak scores by counting, in the `annotatePeaks.pl` output file, the number of motifs found in each peak for each TF.
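`homer_tf_peak_motifs.py` itself is not shown. Its counting step can be sketched as a `groupby`, assuming the Homer motif hits have already been parsed into a long table with one row per motif instance. The table layout and the uppercase TF names below are hypothetical, and how the script parses the `annotatePeaks.pl` columns is not reproduced here.

```python
import pandas as pd

# Hypothetical parsed Homer output: one row per motif instance found in a peak
motif_hits = pd.DataFrame({
    "peak_id": ["chr19:100-700", "chr19:100-700", "chr19:100-700", "chr19:900-1500"],
    "source_id": ["AMYB", "AMYB", "KLF9", "AMYB"],
})

tf_peak_counts = (
    motif_hits
    # Match the gene-symbol capitalization used in the RNA data (e.g. AMYB -> Amyb)
    .assign(source_id=lambda d: d["source_id"].str.capitalize())
    .groupby(["peak_id", "source_id"])
    .size()
    .rename("tf_motifs_in_peak")
    .reset_index()
)
print(tf_peak_counts)
```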

#### Running Homer

In [72]:
!sbatch /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/dev/testing_scripts/run_homer.sh

Submitted batch job 3390268


Once Homer has finished running, we will load the results. We are interested in the number of binding sites for each TF in each peak.

In [93]:
homer_results = pd.read_parquet(os.path.join(output_dir, "homer_tf_to_peak.parquet"), engine="pyarrow")
homer_results = homer_results.reset_index(drop=True)
homer_results["source_id"] = homer_results["source_id"].str.capitalize()
homer_results.head()

|   | peak_id | source_id | tf_motifs_in_peak | homer_binding_score |
|---|---|---|---|---|
| 0 | chr19:11807651-11808251 | Amyb | 1.0 | 0.000325 |
| 1 | chr19:19058520-19059120 | Amyb | 1.0 | 0.000325 |
| 2 | chr19:46528570-46529170 | Amyb | 1.0 | 0.000325 |
| 3 | chr19:41666775-41667375 | Amyb | 1.0 | 0.000325 |
| 4 | chr19:10719611-10720211 | Amyb | 2.0 | 0.00065 |


Let's review what we have so far:

- scRNA-seq gene expression of the TFs and TGs
- scATAC-seq chromatin accessibility of the peaks
- TSS of each gene
- A distance score between each peak and TG within 1 Mb of one another
- A TF-peak binding score from Homer
- 1 kb genomic windows

In [94]:
print(f"\nmesc_rna_data\n{mesc_rna_data.iloc[0:5, 0:1].head()}")
print(f"\nmesc_atac_data_chr19\n{mesc_atac_data_chr19.iloc[0:5, :1].head()}")
print(f"\ngene_tss_df\n{gene_tss_df.head()}")
print(f"\ngenes_near_peaks\n{genes_near_peaks.head()}")
print(f"\nhomer_results\n{homer_results.head()}")
print(f"\nmm10_chr19_windows\n{mm10_chr19_windows.head()}")


mesc_rna_data
         E7.5_rep1#AAACCGGCAGAAATGC-1
gene_id                              
Xkr4                         0.000000
Sox17                        0.000000
Mrpl15                       0.000000
Lypla1                       0.000000
Tcea1                        5.930213

mesc_atac_data_chr19
                       E7.5_rep1#AAACCGGCAGAAATGC-1
peak_id                                            
chr19:3282721-3283321                           0.0
chr19:3291918-3292518                           0.0
chr19:3306995-3307595                           0.0
chr19:3308573-3309173                           0.0
chr19:3321460-3322060                           0.0

gene_tss_df
   chrom    start      end           name score strand
0  chr19  3197703  3197703  1700030N03Rik     .      -
1  chr19  3283010  3283010        Ighmbp2     .      -
2  chr19  3283041  3283041         Mrpl21     .      +
3  chr19  3288919  3288919        Mir6984     .      +
4  chr19  3323300  3323300          Cpt1a     

In [112]:
potential_tgs = genes_near_peaks["target_id"].unique()
print(f"Number of potential TGs: {len(potential_tgs)}")

unique_tfs = homer_results["source_id"].unique()
print(f"Number of unique TFs: {len(unique_tfs)}")

unique_peaks = homer_results["peak_id"].unique()
print(f"Number of unique peaks: {len(unique_peaks)}")

Number of potential TGs: 809
Number of unique TFs: 262
Number of unique peaks: 5585


The scores that we are interested in are:

- `TSS_dist_score` between each peak and each potential TG from `genes_near_peaks`
- `tf_motifs_in_peak` between each TF and each peak from `homer_results`
- RNA expression values for TFs and potential TGs from `rna_first_cell`
- ATAC accessibility values for peaks from `atac_first_cell`

We want to build:
- Matrix 1: (RE accessibility * RE distance score), mapped to the genomic windows
- Matrix 2: (RE accessibility * RE distance score), as a peak x potential TG matrix
- Matrix 3: (TF expression * Homer binding score), as a TF x peak matrix, combined with Matrix 2 over their shared peaks


### Combining TF-RE binding potential with RE regulatory potential values

The (peak accessibility * peak distance score) peak x TG scores and the (TF expression * TF-peak binding potential) TF x peak scores will be multiplied for each shared peak to get the final TF x peak x TG scores.
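The per-peak combination can be written as a broadcasted product, which is what merging the two long tables on `peak_id` computes row by row. The toy matrices below are made up; summing the resulting tensor over the peak axis gives the ordinary TF x TG matrix product, which shows what is lost if the peak axis is collapsed.

```python
import numpy as np

# Toy dimensions: 2 TFs, 3 peaks, 2 TGs
tf_peak = np.array([[0.5, 0.0, 1.0],     # TF x peak: TF expression * Homer binding score
                    [0.0, 2.0, 0.0]])
peak_tg = np.array([[0.9, 0.0],          # peak x TG: accessibility * TSS distance score
                    [0.3, 0.7],
                    [0.0, 0.1]])

# Per-peak product: tf_peak_tg[t, p, g] = tf_peak[t, p] * peak_tg[p, g]
tf_peak_tg = np.einsum("tp,pg->tpg", tf_peak, peak_tg)

# Collapsing the peak axis reproduces the TF x TG matrix product
tf_tg = tf_peak_tg.sum(axis=1)
assert np.allclose(tf_tg, tf_peak @ peak_tg)
print(tf_peak_tg.shape, tf_tg.shape)
```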


We need a unique list of the windows and TGs to create a standard index mapping for mouse chromosome 19.

In [99]:
chr19_gene_dict = tss_bed.to_dataframe()["name"].str.capitalize().to_dict()
print(chr19_gene_dict)
print(len(chr19_gene_dict))
chr19_windows = mm10_chr19_windows
print(chr19_windows)

{0: '1700030n03rik', 1: 'Ighmbp2', 2: 'Mrpl21', 3: 'Mir6984', 4: 'Cpt1a', 5: 'Tesmin', 6: 'Tesmin', 7: 'Gal', 8: 'Ppp6r3', 9: 'Lrp5', 10: '1810055g02rik', 11: '1810055g02rik', 12: 'Kmt5b', 13: 'Gm51271', 14: 'Chka', 15: 'Tcirg1', 16: 'Tcirg1', 17: 'Tcirg1', 18: 'Ndufs8', 19: 'Aldh3b1', 20: 'Unc93b1', 21: 'Aldh3b3', 22: 'Aldh3b2', 23: 'Acy3', 24: 'Tbx10', 25: 'Nudt8', 26: 'Doc2g', 27: 'Ndufv1', 28: 'Gstp1', 29: 'Gstp2', 30: 'Gstp3', 31: 'Cabp2', 32: 'Cabp2', 33: 'Cdk2ap2', 34: 'Pitpnm1', 35: 'Pitpnm1', 36: 'Aip', 37: 'Aip', 38: 'Tmem134', 39: 'Cabp4', 40: 'Gpr152', 41: 'Coro1b', 42: 'Ptprcap', 43: 'Rps6kb2', 44: 'Carns1', 45: 'Tbc1d10c', 46: 'Ppp1ca', 47: 'Rad9a', 48: 'Clcf1', 49: 'Clcf1', 50: 'Pold4', 51: 'Pold4', 52: 'Mir6985', 53: 'Ssh3', 54: 'Ssh3', 55: 'Ankrd13d', 56: 'Grk2', 57: 'Kdm2a', 58: 'A930001c03rik', 59: 'Rhod', 60: 'A930001c03rik', 61: 'Syt12', 62: '2010003k11rik', 63: 'Pcx', 64: 'Pcx', 65: 'Lrfn4', 66: 'Mir6986', 67: 'Rce1', 68: 'Rce1', 69: 'Gm960', 70: 'Sptbn2', 71: 'Rb

To test whether the model can learn under favorable circumstances, we will first test how well it predicts expression for the 50 chr19 TGs that are expressed in the most cells.

In [105]:
chr19_genes = gene_tss_df["name"].str.capitalize().to_list()
rna_df_chr19 = mesc_rna_data[mesc_rna_data.index.isin(chr19_genes)]
rna_df_chr19_most_common_genes = rna_df_chr19.fillna(0).ne(0).sum(axis=1).sort_values(ascending=False)
top_50_expressed_genes = rna_df_chr19_most_common_genes.index[:50].to_numpy()
top_50_expressed_genes

array(['Fau', 'Malat1', 'Eif3a', 'Eef1g', 'Cfl1', 'Fth1', 'Pgam1',
       'Snhg1', 'Tnks2', 'Hells', 'Cox8a', 'Ddb1', 'Btaf1', 'Scd2',
       'Stip1', 'Tcf7l2', 'Gldc', 'Sf3b2', 'Nolc1', 'Atp5md', 'Ppp1r14b',
       'Rtn3', 'Uhrf2', 'Smc3', 'Smc5', 'Oga', 'Sf1', 'Rbp4', 'Mark2',
       'Banf1', 'Rfx3', 'Pten', 'Ganab', 'Gnaq', 'Kdm2a', 'Ppp1ca',
       'Tm9sf3', 'Pdcd11', 'Tmem132a', 'Ppp6r3', 'Tjp2', 'Tle4', 'Incenp',
       'Vti1a', 'Atrnl1', 'Kif11', 'Lcor', 'Pum3', 'Slc3a2', 'Chka'],
      dtype=object)

In [109]:
# Restrict the peak to gene distance dataframe to only include the top 50 TGs
top_50_genes_near_peaks = genes_near_peaks[genes_near_peaks["target_id"].isin(top_50_expressed_genes)]

# Subset the Homer results for peaks within range of the top 50 TGs
top_50_genes_homer_results = homer_results[homer_results["peak_id"].isin(top_50_genes_near_peaks["peak_id"])]

# Restrict RNA expression to the top 50 TGs and TFs binding to peaks within range of the top 50 TGs
top_50_genes_rna_df = rna_df_chr19[
    (rna_df_chr19.index.isin(top_50_expressed_genes)) |
    (rna_df_chr19.index.isin(top_50_genes_homer_results["source_id"]))
    ]

top_50_genes_atac_df = mesc_atac_data_chr19[mesc_atac_data_chr19.index.isin(top_50_genes_near_peaks["peak_id"])]

In [111]:
top_50_potential_tgs = top_50_genes_near_peaks["target_id"].unique()
print(f"Number of potential TGs: {len(top_50_potential_tgs)}")

top_50_unique_tfs = top_50_genes_homer_results["source_id"].unique()
print(f"Number of unique TFs: {len(top_50_unique_tfs)}")

top_50_unique_peaks = top_50_genes_atac_df.index.unique()
print(f"Number of unique peaks: {len(top_50_unique_peaks)}")

Number of potential TGs: 50
Number of unique TFs: 262
Number of unique peaks: 4931


We only want genes that are either potential TGs or TFs from the Homer results (handled above), and cells that are present in both the RNA and ATAC datasets.

In [114]:
atac_cell_barcodes = top_50_genes_atac_df.columns
rna_cell_barcodes = top_50_genes_rna_df.columns

# Keep only cells present in both datasets, in the same (sorted) order
shared_barcodes = sorted(set(atac_cell_barcodes) & set(rna_cell_barcodes))

mesc_atac_data_chr19_shared = top_50_genes_atac_df[shared_barcodes]
mesc_rna_data_shared = top_50_genes_rna_df[shared_barcodes]

In [115]:
rna_first_cell = mesc_rna_data_shared.iloc[:, 0]
atac_first_cell = mesc_atac_data_chr19_shared.loc[:, rna_first_cell.name] # Makes sure the barcodes match
print(rna_first_cell.name, atac_first_cell.name)

E7.5_rep1#AAACCGGCAGAAATGC-1 E7.5_rep1#AAACCGGCAGAAATGC-1


#### TF-peak Binding Potential

We calculate the TF-peak binding potential as the Homer TF-to-peak binding score multiplied by the gene expression of each TF.

In [117]:
tf_peak_binding_potential = pd.merge(top_50_genes_homer_results, rna_first_cell, left_on="source_id", right_index=True, how="inner")
tf_peak_binding_potential["tf_peak_binding_potential"] = tf_peak_binding_potential["homer_binding_score"] * tf_peak_binding_potential[rna_first_cell.name]
tf_peak_binding_potential = tf_peak_binding_potential[["source_id", "peak_id", "tf_peak_binding_potential"]]
tf_peak_binding_potential.head()

|   | source_id | peak_id | tf_peak_binding_potential |
|---|---|---|---|
| 61296 | Klf9 | chr19:38352670-38353270 | 0.0 |
| 61298 | Klf9 | chr19:5151209-5151809 | 0.0 |
| 61300 | Klf9 | chr19:7483054-7483654 | 0.0 |
| 61301 | Klf9 | chr19:46624034-46624634 | 0.0 |
| 61302 | Klf9 | chr19:4300344-4300944 | 0.0 |


#### Peak-TG Regulatory Potential

We calculate the peak-TG regulatory potential as the peak accessibility multiplied by the peak-to-TG TSS distance score.

In [118]:
peak_tg_regulatory_potential = pd.merge(top_50_genes_near_peaks, atac_first_cell, left_on="peak_id", right_index=True, how="inner")
peak_tg_regulatory_potential["peak_tg_regulatory_potential"] = peak_tg_regulatory_potential["TSS_dist_score"] * peak_tg_regulatory_potential[atac_first_cell.name]
peak_tg_regulatory_potential = peak_tg_regulatory_potential[["peak_id", "target_id", "peak_tg_regulatory_potential"]]
peak_tg_regulatory_potential.head()

|   | peak_id | target_id | peak_tg_regulatory_potential |
|---|---|---|---|
| 227332 | chr19:53599784-53600384 | Smc3 | 0.0 |
| 132396 | chr19:24142459-24143059 | Tjp2 | 0.0 |
| 125654 | chr19:14596667-14597267 | Tle4 | 0.0 |
| 137892 | chr19:28010498-28011098 | Rfx3 | 0.0 |
| 158167 | chr19:38124843-38125443 | Rbp4 | 0.0 |


#### TF-Peak-TG Regulatory Potential

Finally, we get the TF-Peak-TG regulatory potential by multiplying the TF-peak binding potential scores with the peak-TG regulatory potential scores for each shared peak.

In [123]:
tf_peak_tg_regulatory_potential = pd.merge(tf_peak_binding_potential, peak_tg_regulatory_potential, on="peak_id", how="outer").dropna()
tf_peak_tg_regulatory_potential["tf_peak_tg_score"] = tf_peak_tg_regulatory_potential["tf_peak_binding_potential"] * tf_peak_tg_regulatory_potential["peak_tg_regulatory_potential"]
tf_peak_tg_regulatory_potential = tf_peak_tg_regulatory_potential[["source_id", "peak_id", "target_id", "tf_peak_tg_score"]]
print(tf_peak_tg_regulatory_potential.head())
print(tf_peak_tg_regulatory_potential.shape)

  source_id                  peak_id target_id  tf_peak_tg_score
5     Glis3  chr19:10012839-10013439      Fth1               0.0
6     Glis3  chr19:10012839-10013439      Fth1               0.0
7     Glis3  chr19:10012839-10013439    Incenp               0.0
8     Glis3  chr19:10012839-10013439      Ddb1               0.0
9     Glis3  chr19:10012839-10013439  Tmem132a               0.0
(11431, 4)


### Aggregate the TF-peak-TG scores into genomic coordinate windows

We will now aggregate the peaks into the 1 kb windows that we created earlier, summing the peak scores within each window to get a final window score. Because the windows have a fixed size and fixed coordinates, the transformer input uses the same token definition for any mm10 sample, regardless of where that sample's peaks were called.

In [127]:
# Parse peak_id into genomic coords (chrom, start, end)
coords = tf_peak_tg_regulatory_potential["peak_id"].str.extract(
    r"(?P<chrom>[^:]+):(?P<start>\d+)-(?P<end>\d+)"
).astype({"start":"int64","end":"int64"})

df = pd.concat([tf_peak_tg_regulatory_potential, coords], axis=1)

df = df[df["chrom"] == "chr19"].copy()

window_size = int((mm10_chr19_windows["end"] - mm10_chr19_windows["start"]).mode().iloc[0])

# Build a quick lookup of window_id strings from window indices
# window index k -> [start=k*w, end=(k+1)*w)
win_lut = {}
for _, row in mm10_chr19_windows.iterrows():
    k = row["start"] // window_size
    win_lut[k] = f'{row["chrom"]}:{row["start"]}-{row["end"]}'

# --- Assign each unique peak to the window with maximal overlap (random ties) ---
rng = np.random.default_rng(0)  # set a seed for reproducibility; change/remove if you want different random choices

peaks_unique = (
    df.loc[:, ["peak_id", "chrom", "start", "end"]]
      .drop_duplicates(subset=["peak_id"])
      .reset_index(drop=True)
)

def assign_best_window(start, end, w):
    # windows indices spanned by the peak (inclusive)
    i0 = start // w
    i1 = (end - 1) // w  # subtract 1 so exact boundary end==k*w goes to k-1 window
    if i1 < i0:
        i1 = i0
    # compute overlaps with all spanned windows
    overlaps = []
    for k in range(i0, i1 + 1):
        bin_start = k * w
        bin_end = bin_start + w
        ov = max(0, min(end, bin_end) - max(start, bin_start))
        overlaps.append((k, ov))
    # choose the k with max overlap; break ties randomly
    ov_vals = [ov for _, ov in overlaps]
    max_ov = max(ov_vals)
    candidates = [k for (k, ov) in overlaps if ov == max_ov]
    if len(candidates) == 1:
        return candidates[0]
    else:
        return rng.choice(candidates)

peak_to_window_idx = peaks_unique.apply(
    lambda r: assign_best_window(r["start"], r["end"], window_size), axis=1
)
peaks_unique["window_idx"] = peak_to_window_idx
peaks_unique["window_id"] = peaks_unique["window_idx"].map(win_lut)

# Map window assignment back to the full TF–peak–gene table and aggregate
df = df.merge(
    peaks_unique.loc[:, ["peak_id", "window_id"]],
    on="peak_id",
    how="left"
)

# Aggregate scores per TF × window × gene
binned_scores = (
    df.groupby(["source_id", "window_id", "target_id"], observed=True)["tf_peak_tg_score"]
      .sum()
      .reset_index()
).rename(columns={"tf_peak_tg_score":"tf_window_tg_score"})

print(binned_scores.head())


  source_id                window_id target_id  tf_window_tg_score
0     Glis3  chr19:10013000-10014000      Ddb1                 0.0
1     Glis3  chr19:10013000-10014000      Fth1                 0.0
2     Glis3  chr19:10013000-10014000    Incenp                 0.0
3     Glis3  chr19:10013000-10014000  Tmem132a                 0.0
4     Glis3  chr19:10015000-10016000      Ddb1                 0.0
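The row-wise `apply` above calls Python once per peak. Because every peak here is 600 bp and the windows are 1 kb, each peak spans at most two windows, so the majority-overlap rule can be vectorized with NumPy. The sketch below assumes exactly that (it raises an error otherwise); `assign_windows_vectorized` is a hypothetical helper. Ties draw from the RNG in a different order than the loop, so evenly split peaks may land in different windows than in the output above.

```python
import numpy as np

def assign_windows_vectorized(starts, ends, window_size, rng):
    """Assign each peak [start, end) to the window it overlaps most; ties broken at random.

    Assumes no peak is longer than window_size, so each peak spans at most two
    consecutive windows (true for 600 bp peaks and 1 kb windows).
    """
    starts = np.asarray(starts, dtype=np.int64)
    ends = np.asarray(ends, dtype=np.int64)
    if np.any(ends - starts > window_size):
        raise ValueError("peaks longer than window_size can span more than two windows")

    first = starts // window_size            # window holding the first base
    last = (ends - 1) // window_size         # window holding the last base
    boundary = (first + 1) * window_size     # end coordinate of the first window
    overlap_first = np.minimum(ends, boundary) - starts
    overlap_last = ends - np.maximum(starts, boundary)   # <= 0 when last == first

    choice = np.where(overlap_last > overlap_first, last, first)
    # Even split across two windows: pick either one with probability 0.5
    tie = (last != first) & (overlap_first == overlap_last)
    coin = rng.random(len(starts)) < 0.5
    return np.where(tie & coin, last, choice)

example_idx = assign_windows_vectorized(
    starts=[100, 800, 700, 2500],
    ends=[700, 1400, 1300, 3100],
    window_size=1000,
    rng=np.random.default_rng(0),
)
print(example_idx)   # the third peak (700-1300) is an even split, so it may be 0 or 1
```

In the notebook, this would replace the `apply` call with `assign_windows_vectorized(peaks_unique["start"], peaks_unique["end"], window_size, rng)`.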


Next, we need to pivot the long dataframe into a 3D TF x window x TG NumPy array.

In [128]:
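# Get unique IDs; order windows by genomic start so neighboring tokens are neighboring windows
tfs = binned_scores["source_id"].unique()
windows = np.array(sorted(
    binned_scores["window_id"].unique(),
    key=lambda w: int(w.split(":")[1].split("-")[0])
))
genes = binned_scores["target_id"].unique()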

# Create index maps
tf_idx = {tf: i for i, tf in enumerate(tfs)}
window_idx = {p: i for i, p in enumerate(windows)}
gene_idx = {g: i for i, g in enumerate(genes)}

# Initialize 3D matrix
tensor_np = np.zeros((len(tfs), len(windows), len(genes)), dtype=float)

# Fill values
for _, row in binned_scores.iterrows():
    i = tf_idx[row["source_id"]]
    j = window_idx[row["window_id"]]
    k = gene_idx[row["target_id"]]
    tensor_np[i, j, k] = row["tf_window_tg_score"]

print(tensor_np.shape)  # (n_TFs, n_windows, n_genes)

(2, 2484, 50)
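The `iterrows` loop above is fine for a few thousand rows but slows down as the table grows. A vectorized alternative, sketched below as a hypothetical helper (`long_to_tensor`), uses `pd.factorize` to turn each ID column into integer codes and scatters the scores with `np.add.at`. Axis labels come back in first-appearance order, like `Series.unique()`, so reorder them afterwards if a genomic window order is needed.

```python
import numpy as np
import pandas as pd

def long_to_tensor(df, tf_col, window_col, tg_col, value_col):
    """Pivot a long TF/window/TG table into a dense (n_TFs, n_windows, n_TGs) array."""
    tf_codes, tf_labels = pd.factorize(df[tf_col])
    win_codes, win_labels = pd.factorize(df[window_col])
    tg_codes, tg_labels = pd.factorize(df[tg_col])
    tensor = np.zeros((len(tf_labels), len(win_labels), len(tg_labels)), dtype=float)
    # np.add.at accumulates if a (TF, window, TG) key repeats instead of overwriting it
    np.add.at(tensor, (tf_codes, win_codes, tg_codes), df[value_col].to_numpy(dtype=float))
    return tensor, np.asarray(tf_labels), np.asarray(win_labels), np.asarray(tg_labels)

toy_binned = pd.DataFrame({
    "source_id": ["Glis3", "Glis3", "Klf9"],
    "window_id": ["chr19:10013000-10014000", "chr19:10015000-10016000", "chr19:10013000-10014000"],
    "target_id": ["Ddb1", "Fth1", "Ddb1"],
    "tf_window_tg_score": [0.5, 1.0, 2.0],
})
toy_tensor, toy_tfs, toy_windows, toy_genes = long_to_tensor(
    toy_binned, "source_id", "window_id", "target_id", "tf_window_tg_score"
)
print(toy_tensor.shape)
```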


Now that we have the input array, we will save it to disk in case the kernel crashes.

In [124]:
np.savez_compressed(
    os.path.join(output_dir, "tf_window_gene_tensor.npz"),
    tensor=tensor_np,
    tfs=tfs,
    windows=windows,
    genes=genes
)

In [129]:
print(f"Number of TFs: {tensor_np.shape[0]}")
print(f"Number of Windows: {tensor_np.shape[1]}")
print(f"Number of TGs: {tensor_np.shape[2]}")

Number of TFs: 2
Number of Windows: 2484
Number of TGs: 50


To recap, the scores in this array represent the TF-window-TG regulatory potential for each TG, built from the Homer TF-peak binding scores, the peak-TG distance scores, the peak accessibility values, and the TF RNA expression values.

---

## Transformer Architecture

We will use the genomic windows as the sequence and let each gene query the windows for evidence to predict its expression. The tokens will be the chr19 windows in the input array (2,484 windows in this top-50-TG subset). The features for each token will be set using a learned linear projection of its TF x TG slice.

### Data Normalization and Range Clamping

We convert the NumPy array into a PyTorch tensor, Z-score normalize each (TF, window) row across the TG axis, and clamp the values to the range [-5, 5].

In [140]:
X = torch.from_numpy(tensor_np).float()     # n_TFs x n_windows x n_TGs (here 2 x 2,484 x 50)

# Z-score each (TF, window) row across the TG axis (dim=2)
X = (X - X.mean(dim=2, keepdim=True)) / (X.std(dim=2, keepdim=True) + 1e-6)
X = torch.clamp(X, -5, 5)

### Linear projection of the TF and TG features

Running a transformer directly on the full TF x window x TG array quickly becomes too computationally heavy as the numbers of TFs, windows, and TGs grow. Instead, we use trainable linear projections to compress the TG axis to 64 channels and the TF axis to 32 channels, then project each window's flattened features to a 256-dimensional embedding.

In [141]:
# Each window token will have d_model = 256 features
TF, W, TG = X.shape
d_model = 256

# Project the TG dimension down to 64 using a linear projection
tg_channels = 64
proj_tg = nn.Linear(TG, tg_channels, bias=False)
Xg = proj_tg(X)             # TF, W, 64

# Project the TF dimension down to 32 using a linear projection
tf_channels = 32
proj_tf = nn.Linear(TF, tf_channels, bias=False)

# Need to permute Xg to get the TFs as the last dimension before projecting
Xg = Xg.permute(1, 2, 0)    # W, 64, TF
Xg = proj_tf(Xg)            # W, 64, 32

# Next, we flatten the 64 tg_channels x 32 tf_channels
window_features = Xg.reshape(W, tg_channels * tf_channels)

# Now we project the window features to the shape of d_model
proj_window = nn.Linear(tg_channels * tf_channels, d_model)
tokens = proj_window(window_features)   # W, 256


In [142]:
tokens.shape

torch.Size([2484, 256])

### Downsampling the Windows

A sequence of 2,484 window tokens is still expensive, because the memory and compute needed for self-attention grow quadratically with the sequence length. We will average-pool the tokens using a kernel size of 8 and a stride of 8, which averages each block of 8 consecutive tokens into one and shortens the sequence about 8-fold.

In [133]:
# Downsample by average pooling across windows
pool = nn.AvgPool1d(kernel_size=8, stride=8)  # along sequence length, bins and pools the data
tokens = tokens.unsqueeze(0)
tokens_ds = pool(tokens.transpose(1,2)).transpose(1,2)   # [1, W', d_model]
W_ds = tokens_ds.size(1)

In [134]:
tokens_ds.shape

torch.Size([1, 310, 256])
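To make the quadratic cost concrete, the sketch below counts the attention scores per layer before and after pooling (8 heads, as in the encoder below) and confirms the pooled length. The float32 memory figure covers only the score matrices, not activations or gradients.

```python
import torch
import torch.nn as nn

n_windows, d_model, n_heads = 2484, 256, 8

# Self-attention scores every (query, key) pair in every head: n_heads * L * L values per layer
for length in (n_windows, n_windows // 8):
    n_scores = n_heads * length * length
    print(f"L = {length:>4}: {n_scores:,} attention scores per layer "
          f"(~{n_scores * 4 / 1e6:.1f} MB as float32)")

# AvgPool1d(kernel_size=8, stride=8) returns floor(L / 8) bins; a trailing
# partial bin (here the last 4 windows) is dropped.
pool = nn.AvgPool1d(kernel_size=8, stride=8)
x = torch.randn(1, d_model, n_windows)   # [batch, channels, length]
pooled = pool(x)
print(pooled.shape)
```

At 2,484 tokens that is about 49 million scores per layer (roughly 197 MB in float32); at 310 tokens it is about 0.77 million (roughly 3 MB), a roughly 64-fold reduction. Passing `ceil_mode=True` to `AvgPool1d` would keep the partial final bin instead of dropping it.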

### Building the transformer encoder layer

Next, we set up a 4-layer transformer encoder with 8 attention heads and a feedforward dimension of 512.

In [135]:
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=8, dim_feedforward=512, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

H = encoder(tokens_ds)   # [1, W', d_model]

### Positional Encoding (Not used for now)

A positional encoding would let the model use the position of each genomic window, so that it can learn whether a gene should be expressed from the regulatory landscape around it.

We will use a local relative position bias (RPB): a learnable bias, per attention head, for each offset between two windows within ±10 kb of one another. Offsets beyond ±10 kb share the bias of the outermost bucket. This lets the attention mechanism learn how much to favor nearby windows over more distant ones.

In [None]:
class RelativePositionBias(nn.Module):
    """
    Learned relative position bias for self-attention over windows.
    - max_kb: e.g., 10 (±10 kb neighborhood)
    - window_size_bp: e.g., 1000 for 1kb windows
    - n_heads: attention heads to broadcast the bias over
    """
    def __init__(self, n_heads: int, window_size_bp: int, max_kb: int):
        super().__init__()
        self.n_heads = n_heads
        self.window_size_bp = window_size_bp
        self.max_offset = max(1, (max_kb * 1000) // window_size_bp)  # in window units
        n_buckets = 2 * self.max_offset + 1  # offsets from -max..+max
        self.bias = nn.Parameter(torch.zeros(n_heads, n_buckets))     # [H, 2*max+1]
        nn.init.zeros_(self.bias)

    def forward(self, seq_len: int, device=None):
        """
        Returns an additive bias tensor for attention scores:
        shape [H, L, L], to be added to the (H, L, L) logits.
        """
        device = device or self.bias.device
        # offsets[i,j] = j - i  in window units
        idx = torch.arange(seq_len, device=device)
        off = idx[None, :] - idx[:, None]   # [L, L] in windows

        # clip to [-max_offset, +max_offset]
        off = off.clamp(-self.max_offset, self.max_offset)
        # map offset -> bucket index [0..2*max_offset], center at 0 offset
        buckets = (off + self.max_offset).long()  # [L, L]

        # gather per-head bias
        # bias: [H, B], buckets: [L, L] -> out [H, L, L]
        out = self.bias[:, buckets]  # advanced indexing
        return out  # [H, L, L]
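The class above is not wired into the encoder yet. One way to use such a bias, sketched here under assumptions, is to pass it as a float `attn_mask` to `nn.MultiheadAttention`, which adds float masks to the attention logits and accepts 3D masks of shape `(batch * num_heads, L, L)`. The block repeats the bucketing logic inline so it runs on its own. Two assumptions are worth flagging: after 8x pooling each token spans about 8 kb, so a ±10 kb neighborhood is only about one token (hence `token_bp = 8000` rather than 1000); and the offsets only measure genomic distance if the tokens are consecutive windows in genomic order, which does not hold when only windows containing peaks are kept.

```python
import torch
import torch.nn as nn

n_heads, d_model, seq_len = 8, 256, 310

# Bucketing as in RelativePositionBias: offsets (in tokens) clipped to +/- max_offset.
# Assumption: after 8x pooling a token spans ~8 kb, so +/-10 kb is about one token.
token_bp, max_kb = 8_000, 10
max_offset = max(1, (max_kb * 1000) // token_bp)
bias_table = nn.Parameter(torch.zeros(n_heads, 2 * max_offset + 1))   # [H, buckets]

idx = torch.arange(seq_len)
offsets = (idx[None, :] - idx[:, None]).clamp(-max_offset, max_offset)  # [L, L]
rel_bias = bias_table[:, offsets + max_offset]                          # [H, L, L]

attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
toy_tokens = torch.randn(1, seq_len, d_model)                           # [batch, L, d_model]

# A float attn_mask is added to the attention logits; a 3D mask has shape
# [batch * n_heads, L, L], so with batch size 1 the [H, L, L] bias fits directly.
out, weights = attn(toy_tokens, toy_tokens, toy_tokens, attn_mask=rel_bias)
out.sum().backward()          # gradients reach bias_table through the mask
print(out.shape, bias_table.grad.shape)
```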

### Cross-attention from the gene queries to the window tokens

Each TG has a learned query vector, and each window token provides key and value vectors. This allows the model to predict each TG's expression from the window tokens that its query attends to.

#### Set up the multi-headed attention

In [136]:
n_genes = len(genes)  # number of TGs (50 in this subset)
gene_embed = nn.Embedding(n_genes, d_model)

cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=8, batch_first=True)
readout = nn.Sequential(
    nn.LayerNorm(d_model),
    nn.Linear(d_model, 1)
)

# Build gene queries (index 0..n_genes-1)
gene_ids = torch.arange(n_genes)
GQ = gene_embed(gene_ids).unsqueeze(0)  # [1, n_genes, d_model]

# Cross-attention: (Q=genes, K/V=windows)
Z, _ = cross_attn(query=GQ, key=H, value=H)  # [1, n_genes, d_model]

pred_expr = readout(Z).squeeze(-1)           # [1, n_genes]

#### Register Parameters

In [137]:
params = (
    list(proj_tg.parameters()) +
    list(proj_tf.parameters()) +
    list(proj_window.parameters()) +
    list(encoder.parameters()) +
    list(gene_embed.parameters()) +
    list(cross_attn.parameters()) +
    list(readout.parameters())
    # + list(rpb.parameters())  # if you use the custom RPB blocks
)
opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-4)

#### Move everything to the same device (CUDA if available)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X = X.to(device)
proj_tg, proj_tf, proj_window, encoder = proj_tg.to(device), proj_tf.to(device), proj_window.to(device), encoder.to(device)
gene_embed, cross_attn, readout = gene_embed.to(device), cross_attn.to(device), readout.to(device)
tokens_ds = tokens_ds.to(device)

#### Get the true TG expression values to calculate error

The TG expression vector needs to be in the same order and have the same length as the TG predictions (the `genes` array).

In [138]:
dup = pd.Index(genes).duplicated()
assert not dup.any(), f"Duplicate gene IDs in prediction axis at: {np.where(dup)[0][:10]}"

# Align counts to prediction order from the gene to index mapping (same length and order as genes)
true_counts = rna_first_cell.reindex(genes)

# build mask for missing genes (not present in RNA)
mask = ~true_counts.isna().to_numpy()

# Handle missing genes using a masked loss 
y_true_vec = true_counts.to_numpy(dtype=float)        # shape (n_genes,)

y_true = torch.tensor(y_true_vec, dtype=torch.float32).unsqueeze(0)   # [1, n_genes]
mask_t = torch.tensor(mask, dtype=torch.bool).unsqueeze(0)            # [1, n_genes]

### Loss Calculation

We will calculate the MSE between the true expression values of the TGs and the predicted TG expression values.
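Why the mask should be applied before subtracting: if `y_true` contains NaN for genes missing from the RNA data, computing `(pred - y) ** 2` first and masking afterwards still gives a finite loss, but backpropagation multiplies the zero gradient of each masked-out entry by `2 * (pred - NaN)`, which is NaN. The toy values below are made up, and `nan_safe_masked_mse` is an illustrative name.

```python
import torch

def nan_safe_masked_mse(pred, y, m):
    """MSE over entries where m is True; masked-out targets never enter the graph."""
    y, m = y.to(pred.device), m.to(pred.device)
    return ((pred[m] - y[m]) ** 2).mean()

y = torch.tensor([[1.5, float("nan"), 2.0]])   # NaN = gene missing from the RNA data
m = ~torch.isnan(y)

# Index first: finite loss and finite gradients
pred = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
toy_loss = nan_safe_masked_mse(pred, y, m)
toy_loss.backward()
print(toy_loss.item(), pred.grad)        # loss 0.625; gradient [-0.5, 0.0, 1.0]

# Mask after squaring: the loss is still finite, but the masked entry's gradient is NaN
pred_naive = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
naive_loss = ((pred_naive - y) ** 2)[m].mean()
naive_loss.backward()
print(naive_loss.item(), pred_naive.grad)
```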

In [139]:
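def masked_mse(pred, y, m):
    # Index before subtracting so NaN targets (genes missing from RNA) never enter the graph
    y, m = y.to(pred.device), m.to(pred.device)
    return ((pred[m] - y[m]) ** 2).mean()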

# pred_expr: [1, n_genes] from the model
loss = masked_mse(pred_expr, y_true, mask_t)
print(loss)

tensor(21.2766, grad_fn=<MeanBackward0>)
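So far the forward pass runs once, outside any loop, and the device-move cell comes after it, so nothing has been trained yet. A training loop has to rebuild the forward graph on every step. The sketch below bundles the same layers into one hypothetical `nn.Module` (`TGExpressionModel`), so that `model.parameters()` replaces the manual parameter list and a single `.to(device)` moves everything. It runs on small random toy data, not on the notebook's tensors; `X`, `y_true`, and `mask_t` would take the place of `toy_x`, `toy_y`, and `toy_mask`.

```python
import torch
import torch.nn as nn

class TGExpressionModel(nn.Module):
    """Hypothetical wrapper bundling the notebook's layers into one module."""

    def __init__(self, n_tfs, n_tgs, d_model=256, tg_channels=64, tf_channels=32,
                 pool_size=8, n_heads=8, n_layers=4):
        super().__init__()
        self.proj_tg = nn.Linear(n_tgs, tg_channels, bias=False)
        self.proj_tf = nn.Linear(n_tfs, tf_channels, bias=False)
        self.proj_window = nn.Linear(tg_channels * tf_channels, d_model)
        self.pool = nn.AvgPool1d(kernel_size=pool_size, stride=pool_size)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=512, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.gene_embed = nn.Embedding(n_tgs, d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.readout = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, 1))

    def forward(self, x):                                   # x: [TF, W, TG]
        n_windows = x.shape[1]
        h = self.proj_tg(x).permute(1, 2, 0)                # [W, tg_channels, TF]
        h = self.proj_tf(h).reshape(n_windows, -1)          # [W, tg_channels * tf_channels]
        tokens = self.proj_window(h).unsqueeze(0)           # [1, W, d_model]
        tokens = self.pool(tokens.transpose(1, 2)).transpose(1, 2)  # [1, W // pool_size, d_model]
        hidden = self.encoder(tokens)                       # [1, W', d_model]
        queries = self.gene_embed.weight.unsqueeze(0)       # [1, TG, d_model]
        z, _ = self.cross_attn(queries, hidden, hidden)     # [1, TG, d_model]
        return self.readout(z).squeeze(-1)                  # [1, TG]

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-ins for X, y_true, and mask_t
n_tfs, n_windows, n_tgs = 2, 64, 10
toy_x = torch.randn(n_tfs, n_windows, n_tgs, device=device)
toy_y = torch.rand(1, n_tgs, device=device) * 5
toy_mask = torch.ones(1, n_tgs, dtype=torch.bool, device=device)

model = TGExpressionModel(n_tfs, n_tgs).to(device)
toy_opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

losses = []
model.train()
for step in range(50):
    toy_opt.zero_grad()
    pred = model(toy_x)                                     # rebuild the graph each step
    step_loss = ((pred[toy_mask] - toy_y[toy_mask]) ** 2).mean()
    step_loss.backward()
    toy_opt.step()
    losses.append(step_loss.item())

print(f"loss at step 0: {losses[0]:.3f}, at step 49: {losses[-1]:.3f}")
```

Wrapping the layers in one module also makes it straightforward to add the relative position bias later, since its parameters would be registered automatically.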
