In [None]:
import torch
import os
import numpy as np
from torch.utils import data
from torch.nn import DataParallel
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt 
import pandas as pd
import seaborn as sns
import json
import tqdm
import sys
import pyfaidx
sys.path.append("../src/regulatory_lm/")
from evals.nucleotide_dependency import *
from modeling.model import *
from utils.viz_sequence import *
from utils.bpnet import BPNet


This notebook predicts counts for two sets of generated sequences using the supervised ChromBPNet models that served as targets for these cell type-specific generations. It is intended to verify that the generated sequences are indeed cell type-specific.

In [None]:
#Load the reference genome and cut out the background sequence that the generated
#sequences will later be inserted into.
#You'll need to replace the genome with your own path
genome = "/mnt/lab_data2/regulatory_lm/oak_backup/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"
genome_data = pyfaidx.Fasta(genome, sequence_always_upper=True)

#Locus of interest and the total length of the window extracted around it
chrom, seq_len = "chr4", 2114
locus_start, locus_end = 39469376, 39469725

#Re-centre: the extracted window spans seq_len bases around the locus midpoint
midpoint = (locus_start + locus_end) // 2
half_window = seq_len // 2
start, end = midpoint - half_window, midpoint + half_window
print(midpoint, start, end)

#Upper-case reference bases for [start, end) on the chosen chromosome
dna_seq = genome_data[chrom][start:end].seq


In [None]:
def dna_to_one_hot(seqs):
    """
    One-hot encode a batch of equal-length DNA strings.

    `seqs` is a list of N strings that must all share the same length L (an
    AssertionError is raised otherwise). The result is an N x L x 4 float
    Pytorch tensor whose last axis is ordered "ACGT", with rows in the same
    order as the input strings. Lower-case bases are upper-cased before
    encoding, and any character other than A, C, G or T is encoded as all 0s.
    """
    n_bases = len(seqs[0])
    assert all(len(s) == n_bases for s in seqs)

    # Flatten the batch into one upper-case string. A trailing "ACGT" sentinel
    # guarantees every canonical base occurs at least once, so the rank
    # returned by np.unique below is always A=0, C=1, G=2, T=3.
    flat = "".join(seqs).upper() + "ACGT"

    # Lookup table: rows 0-3 are the unit vectors for A, C, G, T and row 4 is
    # all zeros (used for every non-ACGT character)
    lookup = np.eye(5, 4, dtype=np.int8)

    # ASCII codes of the flattened string, in a writable buffer
    codes = np.frombuffer(bytearray(flat, "utf8"), dtype=np.int8)

    # Collapse everything that is not A/C/G/T onto a single code that sorts
    # after "T" (84), so it receives rank 4
    acgt_codes = np.array([65, 67, 71, 84])
    is_acgt = np.isin(codes, acgt_codes)
    codes[~is_acgt] = 85

    # Rank of each code among the sorted distinct codes -> index in [0, 4]
    _, ranks = np.unique(codes, return_inverse=True)

    # Drop the 4 sentinel positions, look up the encodings, and split the flat
    # array back into one row per input sequence
    encoded = lookup[ranks[:-4]].reshape((len(seqs), n_bases, 4))
    return torch.tensor(encoded).float()


In [None]:
def supervised_predict_counts(supervised_model, seq_str, device):
    """
    Run a supervised bpnet-style model on a single DNA string and return one scalar.

    The string is one-hot encoded as a batch of size one, moved to `device`,
    and passed through `supervised_model` with gradient tracking disabled.
    Element 1 of the model output is returned as a Python number
    (presumably the counts head of the model -- confirm against utils.bpnet).
    """
    batch = dna_to_one_hot([seq_str])
    batch = batch.to(device)

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        model_outputs = supervised_model(batch)

    counts_output = model_outputs[1]
    return counts_output.item()


In [None]:
#Here, we define the two ChromBPNet models we are using (again replace with your own path)
#In the paper example, we had one from HEPG2 and one from H1-hESC

#No cell above assigns `device`. Keep whatever value the wildcard imports may have
#provided; otherwise fall back to the GPU when available so that a fresh
#"Restart & Run All" does not stop here with a NameError.
try:
    device
except NameError:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Each model is converted from its Keras .h5 checkpoint and moved to `device`
hepg2_model_file = "/oak/stanford/groups/akundaje/projects/chromatin-atlas-2022/DNASE/ENCSR149XIL/chrombpnet_model/chrombpnet_wo_bias.h5"
hepg2_chrombpnet_model = BPNet.from_keras(hepg2_model_file)
hepg2_chrombpnet_model = hepg2_chrombpnet_model.to(device)

h1esc_model_file = "/oak/stanford/groups/akundaje/projects/chromatin-atlas-2022/DNASE/ENCSR000EMU/chrombpnet_model/chrombpnet_wo_bias.h5"
h1esc_chrombpnet_model = BPNet.from_keras(h1esc_model_file)
h1esc_chrombpnet_model = h1esc_chrombpnet_model.to(device)


In [None]:
#We will now take in our generated sequences (again replace with your path)
def read_generated_seqs(seqs_file):
    """
    Read generated sequences from a text file holding one sequence per line.

    Leading/trailing whitespace is stripped from every line. Blank lines are
    skipped so that they cannot end up as zero-length inserts when the
    sequences are spliced into the reference. The file is opened in a `with`
    block so the handle is closed as soon as reading finishes.

    Returns a list of sequence strings in file order.
    """
    with open(seqs_file, "r") as handle:
        stripped_lines = (line.strip() for line in handle)
        return [line for line in stripped_lines if line]


hepg2_seqs_file = "/mnt/lab_data2/regulatory_lm/scratch/transformer_test/run_20251231_230449/generated_seqs/hepg2_high_modisco/hepg2_high_vs_h1esc_allruns.txt"
hepg2_seqs = read_generated_seqs(hepg2_seqs_file)

h1esc_seqs_file = "/mnt/lab_data2/regulatory_lm/scratch/transformer_test/run_20251231_230449/generated_seqs/h1esc_high_modisco/h1esc_high_vs_hepg2_allruns.txt"
h1esc_seqs = read_generated_seqs(h1esc_seqs_file)


In [None]:
#We perform counts prediction using both models on both sets of sequences

#Slice of the reference window that each generated sequence replaces
center_start, center_end = 882, 1232


def predict_insert_counts(supervised_model, insert_seqs, background_seq):
    """
    Score every insert with `supervised_model` after splicing it into the background.

    For each string in `insert_seqs`, a construct is built as
    background_seq[:center_start] + insert + background_seq[center_end:]
    and passed to `supervised_predict_counts`. `background_seq` itself is never
    modified, so every construct uses the same flanks and re-running the cell
    gives the same result. If an insert's length differs from
    center_end - center_start, only that construct changes length; later
    constructs are unaffected.

    Returns a list with one predicted value per insert, in input order.
    """
    # The flanks are identical for every insert, so slice them once
    left_flank = background_seq[:center_start]
    right_flank = background_seq[center_end:]
    # supervised_predict_counts already disables gradient tracking internally
    return [
        supervised_predict_counts(supervised_model, left_flank + insert + right_flank, device)
        for insert in insert_seqs
    ]


#HEPG2 model: "high" = HEPG2-high sequences, "low" = H1ESC-high sequences
hepg2_counts_high = predict_insert_counts(hepg2_chrombpnet_model, hepg2_seqs, dna_seq)
hepg2_counts_low = predict_insert_counts(hepg2_chrombpnet_model, h1esc_seqs, dna_seq)

#H1ESC model: "high" = H1ESC-high sequences, "low" = HEPG2-high sequences
h1esc_counts_low = predict_insert_counts(h1esc_chrombpnet_model, hepg2_seqs, dna_seq)
h1esc_counts_high = predict_insert_counts(h1esc_chrombpnet_model, h1esc_seqs, dna_seq)


In [None]:
#Mean HEPG2-model prediction: HEPG2-high generated sequences first, H1ESC-high generated sequences second
np.mean(hepg2_counts_high), np.mean(hepg2_counts_low)

In [None]:
#Mean H1ESC-model prediction: H1ESC-high generated sequences first, HEPG2-high generated sequences second
np.mean(h1esc_counts_high), np.mean(h1esc_counts_low)

In [None]:
#We can plot the stats from both models and verify the sequences are as desired
#(seaborn is already imported as `sns` in the import cell at the top of the notebook)

#Density of HEPG2-model predictions for each of the two generated sequence sets
fig, ax = plt.subplots(dpi=300, figsize=[6, 2])
sns.kdeplot(hepg2_counts_high, lw=2, label="HEPG2 High", ax=ax)
sns.kdeplot(hepg2_counts_low, lw=2, label="H1ESC High", ax=ax)
ax.set_title("HEPG2 ChromBPNet Predicted Counts")
ax.set_xlabel("Predicted Counts")
ax.set_ylabel("Density")
#Hide the y tick marks; only the shape of the two curves is shown
ax.set_yticks([])
ax.legend()
plt.show()


In [None]:
#Density of H1ESC-model predictions for each of the two generated sequence sets
fig, ax = plt.subplots(dpi=300, figsize=[6, 2])
ax.set_title("H1ESC ChromBPNet Predicted Counts")
for counts, curve_label in (
    (h1esc_counts_low, "HEPG2 High"),
    (h1esc_counts_high, "H1ESC High"),
):
    sns.kdeplot(counts, lw=2, label=curve_label, ax=ax)
ax.set_xlabel("Predicted Counts")
ax.set_ylabel("Density")
#Hide the y tick marks; only the shape of the two curves is shown
ax.set_yticks([])
ax.legend()
plt.show()
