In [None]:
#Need to specify the path to arsenal_chrombpnet repository
#Set the ARSENAL_CHROMBPNET_PATH environment variable to override the default below without editing the notebook
import os
arsenal_chrombpnet_path = os.environ.get("ARSENAL_CHROMBPNET_PATH", "/users/patelas/arsenal-chrombpnet/")

In [None]:
import torch
import os
import numpy as np
from torch.utils import data
from torch.nn import DataParallel
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt 
import pandas as pd
import seaborn as sns
import json
import tqdm
import sys
import pyfaidx
sys.path.append("../src/regulatory_lm/")
from evals.nucleotide_dependency import *
from modeling.model import *
from utils.viz_sequence import *
sys.path.append(f"{arsenal_chrombpnet_path}/chrombpnet/")
from bpnet import BPNet


This notebook demonstrates how to use ARSENAL as part of a guided sequence generation pipeline. This pipeline employs a beam search strategy, with several rounds of generation. After each round, we score our generated sequences and keep the top $k$ according to a user-defined objective function. 

The notebook is set up to use the supervised model ChromBPNet as an oracle to score generations, as in the examples in the ARSENAL paper. However, this is not necessary. Any other model - or no model - can be used by redefining `supervised_predict_counts` and the objective function accordingly. 

# Define some relevant functions

In [None]:
def dna_to_one_hot(seqs):
    """
    One-hot encode a list of equal-length DNA strings.

    Columns are ordered A, C, G, T. Sequences are upper-cased first, so case does
    not matter. Every character other than A/C/G/T (e.g. "N") is encoded as an
    all-zero row.

    Parameters
    ----------
    seqs : list of str
        N strings, all of the same length L.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (N, L, 4), rows in the same order as `seqs`.
    """
    seq_len = len(seqs[0])
    assert np.all(np.array([len(s) for s in seqs]) == seq_len)

    # Lookup table from byte value to row of `one_hot_rows`.
    # Every byte that is not A/C/G/T points at row 4, the all-zero row.
    code_to_row = np.full(256, 4, dtype=np.int64)
    for row, base in enumerate("ACGT"):
        code_to_row[ord(base)] = row
    # Rows 0-3 form the 4x4 identity; row 4 is all zeros.
    one_hot_rows = np.eye(5, 4, dtype=np.int8)

    codes = np.frombuffer("".join(seqs).upper().encode("utf8"), dtype=np.uint8)
    encoded = one_hot_rows[code_to_row[codes]].reshape((len(seqs), seq_len, 4))
    return torch.tensor(encoded).float()


# Registry of module classes keyed by the embedder/encoder/decoder names stored in args.json.
# NOTE(review): MODULES is not defined in this notebook; presumably it arrives via one of the star imports above.
model_str_dict = MODULES
FLOAT_DTYPES = dict(
    float32=torch.float32,
    float64=torch.float64,
    bfloat16=torch.bfloat16,
    float16=torch.float16,
)
# Base -> token id (A=0, C=1, G=2, T=3).
# The inverse map is derived from it so the two cannot drift apart.
MAPPING = {base: token for token, base in enumerate("ACGT")}
REVERSE_MAPPING = {token: base for base, token in MAPPING.items()}

def encode_sequence(sequence):
    """
    Convert a DNA string into a list of ARSENAL token ids.

    A/C/G/T map to 0-3 via MAPPING. Matching is case-insensitive, so soft-masked
    (lowercase) genome sequence encodes the same as uppercase, consistent with
    dna_to_one_hot and revcomp_string. Any other character (e.g. "N") maps to 4.
    """
    return [MAPPING.get(nucleotide.upper(), 4) for nucleotide in sequence]

def encode_sequence_tensor(sequence, device):
    """Encode a DNA string with encode_sequence and return the token ids as a tensor moved to `device`."""
    token_ids = encode_sequence(sequence)
    as_tensor = torch.tensor(token_ids)
    return as_tensor.to(device)

def revcomp(seq_list):
    """
    Reverse-complement a sequence of base token ids (0=A, 1=C, 2=G, 3=T).

    Returns a new list.

    Tokens outside 0-3 (e.g. the unknown-base token 4 or a mask token) have no
    complement. They are kept as-is and only moved to their reversed position,
    instead of being turned into invalid negative ids by 3 - x.
    """
    return [3 - x if 0 <= x <= 3 else x for x in seq_list][::-1]

def revcomp_string(dna_sequence):
    """
    Return the reverse complement of a DNA string, upper-cased.

    A<->T and C<->G are swapped. Any other character (e.g. "N") has no complement
    and is kept unchanged rather than raising KeyError.
    """
    complement = str.maketrans("ACGT", "TGCA")
    return dna_sequence.upper().translate(complement)[::-1]

def decode_seq_tensor(seq_tensor):
    """
    Convert a 1-D tensor of token ids back into a DNA string.

    Ids 0-3 decode through REVERSE_MAPPING. Any other id decodes to "N" instead of
    raising KeyError; this covers the unknown-base token 4 produced by
    encode_sequence, and any special token the model may sample, since sampling is
    over the full vocabulary. Sequences containing N therefore survive the
    generation loop, and re-encoding "N" yields token 4 again.
    """
    return "".join(REVERSE_MAPPING.get(token, "N") for token in seq_tensor.tolist())


def load_model(args_json, saved_model_file, map_location="cpu"):
    '''
    Build an ARSENAL RegulatoryLM from a training-args JSON and load checkpoint weights.

    Parameters
    ----------
    args_json : str
        Path to the JSON file of training arguments. It must contain "embedder",
        "encoder", "decoder", "embedding_size", "num_encoder_layers" and
        "num_real_tokens", plus optional "embedder_kwargs", "encoder_kwargs" and
        "decoder_kwargs".
    saved_model_file : str
        Path to a checkpoint whose "model_state" entry holds the state dict.
    map_location : str or torch.device, optional
        Passed to torch.load. It defaults to "cpu" so a checkpoint saved on a GPU can
        also be read on machines without one. The freshly built model is not moved
        to any device here; callers do that with .to(device).

    Returns
    -------
    The model in eval mode. When the checkpoint keys are not "module."-prefixed, it
    is wrapped with torch.compile.
    '''
    with open(args_json, "r") as args_file:
        args = json.load(args_file)
    embedder_kwargs = args.get("embedder_kwargs", {})
    encoder_kwargs = args.get("encoder_kwargs", {})
    decoder_kwargs = args.get("decoder_kwargs", {})

    # NOTE(review): the +2 presumably reserves ids beyond the real bases for special tokens (e.g. unknown and mask) — confirm.
    embedder = model_str_dict[args["embedder"]](args["embedding_size"], vocab_size=args["num_real_tokens"]+2, masking=True, **embedder_kwargs)
    encoder = model_str_dict[args["encoder"]](args["embedding_size"], args["num_encoder_layers"], **encoder_kwargs)
    decoder = model_str_dict[args["decoder"]](args["embedding_size"], **decoder_kwargs)
    model = RegulatoryLM(embedder, encoder, decoder)

    model_info = torch.load(saved_model_file, map_location=map_location)
    state = model_info["model_state"]
    prefix = "module."
    if state and next(iter(state)).startswith(prefix):
        # Saved from a DataParallel-wrapped model: strip the wrapper prefix from every key.
        model_info["model_state"] = {key[len(prefix):]: value for key, value in state.items()}
    else:
        # NOTE(review): presumably this checkpoint was saved from a torch.compile'd model, whose keys
        # carry the compile wrapper's prefix; compiling here makes the key names line up — confirm.
        model = torch.compile(model)
    model.load_state_dict(model_info["model_state"])
    model.eval()
    return model


# Define the generation functions

`target_generation` is the function that runs the overall targeted generation algorithm. Please see the ARSENAL paper for a detailed description of the algorithm. Here are some key parameters:

* `arsenal_model` - ARSENAL model to use
* `supervised_models` - list of supervised model(s) to use for scoring generations. Can be an empty list if you don't want to use any supervised models; it just needs to be compatible with the rest of the pipeline (in particular, `supervised_predict_counts` and the objective function)
* `starting_seq` - starting sequence for the algorithm. Can be any length; the algorithm only edits the central `len_to_edit` bp (350 by default) and keeps the flanks fixed
* `count_target` - target values to compare with the scores you produce. Can be any data type or length; it just needs to be compatible with the scores and the objective function
* `error_fn` - objective function mapping (`count_target`, predictions) to a scalar score; lower scores are better and are the ones kept
* `seqs_to_gen` - number of sequences to generate (and keep after every step in our beam search)
* `gen_iters` - number of rounds of the beam search
* `iters_per_seq` - number of iterations of masking and resampling in each generation step
* `mask_ratio` - expected fraction of tokens to mask in each iteration (each position is masked independently with this probability)
* `temp` - temperature for scaling of logits before resampling


The other key function to edit is `supervised_predict_counts`. This function takes in a sequence (the starting sequence with its center replaced by a generated sequence) and uses a supervised model to predict a particular value, which is eventually compared with one of the target values. The current implementation works for a ChromBPNet model (it returns the model's second output, the count prediction); if another model is desired, it will have to be changed. If the user wants a value that does not involve any model predictions, the function can take a dummy argument for the model and define the calculations accordingly. Note that this function is run separately for every supervised model specified. 


In [None]:
def extract_seqs_and_context(start_seq, extract_len):
    '''
    Split a sequence into its central `extract_len` characters and the flanks around them.

    Returns (center_seq, left, right), with left + center_seq + right == start_seq.
    The center starts at len(start_seq) // 2 - extract_len // 2 and is exactly
    extract_len long, for odd lengths too (previously an odd extract_len lost one
    base). If extract_len >= len(start_seq), the whole sequence is the center and
    both flanks are empty.

    Raises ValueError if extract_len is negative.
    '''
    if extract_len < 0:
        raise ValueError(f"extract_len must be non-negative, got {extract_len}")
    start_len = len(start_seq)
    # Clamp so an oversized request edits the whole sequence instead of wrapping via negative indices.
    extract_len = min(extract_len, start_len)
    center_start = start_len // 2 - extract_len // 2
    center_end = center_start + extract_len
    center_seq = start_seq[center_start:center_end]
    left, right = start_seq[:center_start], start_seq[center_end:]
    return center_seq, left, right

def filter_best(seqs_and_errors, num_to_keep):
    '''
    Keep the num_to_keep entries with the lowest error.

    seqs_and_errors is an iterable of (seq, error) tuples, where a lower error is
    better. Returns a new list sorted by ascending error and truncated to
    num_to_keep. Ties keep their input order, because list.sort is stable.
    '''
    ranked = list(seqs_and_errors)
    ranked.sort(key=lambda entry: entry[1])
    return ranked[:num_to_keep]

In [None]:
#Here are the functions for actual sequence generation
def mask_and_resample(model, seq, mask_token=5, mask_ratio=0.15, temp=1.0):
    '''
    Run one round of masked resampling with a masked language model.

    Each position of `seq` is masked independently with probability `mask_ratio`.
    The masked sequence is passed through `model`, and every masked position is
    replaced by a token sampled from softmax(logits / temp). Unmasked positions are
    copied unchanged, and `seq` itself is not modified.

    Parameters
    ----------
    model : callable
        Called as model(batch, None) on a (1, N) token batch; expected to return
        logits of shape (1, N, V).
    seq : torch.Tensor
        1-D tensor of N integer token ids.
    mask_token : int
        Token id written into masked positions before the forward pass.
    mask_ratio : float
        Per-position masking probability.
    temp : float
        Softmax temperature; values below 1 sharpen the distribution, values above 1 flatten it.

    Returns
    -------
    torch.Tensor
        A new 1-D token tensor with the same shape, dtype and device as `seq`.
    '''
    to_mask = torch.rand(seq.shape, device=seq.device) < mask_ratio
    out = seq.clone()
    if not to_mask.any():
        # Nothing was masked this round (common for small mask_ratio). Skip the
        # forward pass instead of sampling from an empty probability matrix.
        return out

    masked_seq = torch.where(to_mask, mask_token, seq)
    with torch.no_grad():
        logits = model(masked_seq.unsqueeze(0), None) / temp  # shape: (1, N, V)
    # torch.softmax is used instead of F.softmax because F (torch.nn.functional) is not
    # explicitly imported in this notebook. The float32 cast keeps sampling stable for
    # half-precision logits.
    probs = torch.softmax(logits.float(), dim=-1).squeeze(0)  # shape: (N, V)
    # Sample one token for each masked position only.
    sampled_indices = torch.multinomial(probs[to_mask], num_samples=1).squeeze(1)  # shape: (num_masked,)
    out[to_mask] = sampled_indices.to(out.dtype)
    return out


def iterative_generation(model, seq, iters=100, mask_token=5, mask_ratio=0.15, temp=1.0):
    '''
    Apply `iters` successive rounds of mask_and_resample.

    Each round starts from the previous round's output. Returns the final token
    tensor, or `seq` itself when iters <= 0.
    '''
    current = seq
    for _ in range(iters):
        current = mask_and_resample(model, current, mask_token=mask_token, mask_ratio=mask_ratio, temp=temp)
    return current

In [None]:
#Here are the functions to perform and evaluate generations
def generate_and_predict(arsenal_model, supervised_models, seq, context, iters_per_seq, mask_token, mask_ratio, temp, device):
    '''
    Resample `seq` with ARSENAL, then score the result with every supervised model.

    `context` is a (left, right) pair of flanking strings. The generated central
    sequence is placed between them before scoring, so each supervised model sees
    the full-length input.

    Returns (generated_central_seq, [one prediction per supervised model]).
    '''
    left_flank, right_flank = context[0], context[1]

    start_tokens = encode_sequence_tensor(seq, device)
    generated_tokens = iterative_generation(arsenal_model, start_tokens, iters_per_seq, mask_token, mask_ratio, temp)
    generated_seq = decode_seq_tensor(generated_tokens)

    full_seq = left_flank + generated_seq + right_flank
    predictions = []
    for sup_model in supervised_models:
        predictions.append(supervised_predict_counts(sup_model, full_seq, device))
    return generated_seq, predictions
    
def supervised_predict_counts(supervised_model, seq_str, device):
    '''
    Predict a single scalar for one DNA string with a BPNet-style supervised model.

    The sequence is one-hot encoded as a batch of one, and the model runs without
    gradient tracking. The model's second output (presumably the count head for
    BPNet/ChromBPNet models) is returned as a Python number via .item().
    '''
    # NOTE(review): assumes the model accepts (batch, length, 4) one-hot input as produced by dna_to_one_hot — confirm for the oracle in use.
    batch = dna_to_one_hot([seq_str]).to(device)  # shape: (1, L, 4)
    with torch.no_grad():
        outputs = supervised_model(batch)
    count_pred = outputs[1]
    return count_pred.item()


In [None]:
#Here is the full target generation loop
def target_generation(arsenal_model, supervised_models, starting_seq, count_target, device, error_fn, seqs_to_gen=100, gen_iters=50, len_to_edit=350, iters_per_seq=100,
                                mask_token=5, mask_ratio=0.15, temp=1.0):
    '''
    Run the beam-search, oracle-guided ARSENAL generation pipeline.

    1. Split `starting_seq` into a central `len_to_edit` region (the only part that
       is edited) and fixed left/right flanks.
    2. Generate one sequence from the central region, score it with
       `supervised_models` and `error_fn`, and use it to seed the beam.
    3. In each of `gen_iters` rounds, derive one child from every sequence currently
       in the beam. Then keep the `seqs_to_gen` lowest-error entries among parents
       and children.
    4. Return the final beam: a list of (central_seq, error) tuples in ascending
       error order. With gen_iters=0 it holds only the seed entry.

    supervised_models and count_target must be iterables.
    error_fn is called as error_fn(count_target, predictions) and must return a
    score where lower is better.
    '''
    center_seq, left_flank, right_flank = extract_seqs_and_context(starting_seq, len_to_edit)
    flanks = (left_flank, right_flank)

    def propose(parent_seq):
        # Derive one child from parent_seq and score it against count_target.
        child_seq, child_counts = generate_and_predict(arsenal_model, supervised_models, parent_seq, flanks, iters_per_seq, mask_token, mask_ratio, temp, device)
        return child_seq, error_fn(count_target, child_counts)

    # Seed the beam with a single generation so every round has a parent to expand.
    beam = [propose(center_seq)]

    for round_idx in range(gen_iters):
        print("Iteration ", round_idx)
        children = [propose(parent_seq) for parent_seq, _ in beam]
        # Parents stay eligible, so the best error never gets worse between rounds.
        beam = filter_best(beam + children, seqs_to_gen)

    return beam



# Define some useful objective functions

These are simple objective functions which are useful for scoring generations and form an essential part of the pipeline. 

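Feel free to define your own functions. The only constraint is that they must take in two arguments — target values (`true_counts`) and some values produced from the generations (`predicted_counts`; the outputs of the calls to the `supervised_predict_counts` function) — and must return a scalar score. **Lower scores are better**: after each round the pipeline keeps the sequences with the lowest scores. An objective you want to maximize must therefore be written so that smaller is better (e.g. by negating it, as `max_diff` does). How each of the inputs is used to produce the score (if at all) is up to the user. 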

In [None]:
def absolute_error(true_counts, predicted_counts):
    '''
    Sum of absolute differences between each target and the matching prediction.

    The sum is taken over the positions of true_counts; lower is better.
    '''
    total = 0
    for idx in range(len(true_counts)):
        total += abs(true_counts[idx] - predicted_counts[idx])
    return total

def squared_dist(true_counts, predicted_counts):
    '''
    Euclidean (L2) distance between the target and predicted value vectors.

    Despite the name, the result is NOT squared: it is sqrt(sum((t - p)**2)), as
    computed by np.linalg.norm. Ranking sequences by this value or by its square
    gives the same order. Lower is better.
    '''
    difference = np.asarray(true_counts) - np.asarray(predicted_counts)
    return np.linalg.norm(difference)


def max_diff(true_counts, predicted_counts):
    '''
    Objective favouring a high prediction from the first supervised model relative to the second.

    Returns predicted_counts[1] - predicted_counts[0]. Because the pipeline keeps
    the lowest scores, minimising this maximises model 0's prediction over model
    1's. `true_counts` is ignored, and entries beyond the first two are unused.
    '''
    second_pred = predicted_counts[1]
    first_pred = predicted_counts[0]
    return second_pred - first_pred


# Example: ChromBPNet-guided generation pipeline

Here, we show an example of ChromBPNet-guided generation from the paper. We want to generate sequences which have high predicted counts in the HepG2 cell line and low predicted counts in the H1-hESC cell line. We will use a target of 8 for the HepG2 model's count prediction and 1 for the H1-hESC model's (ChromBPNet's count head predicts log-scale counts), scored with the absolute error objective function defined above. 

In [None]:
#Load ARSENAL Model
# Choose the compute device once. Later cells (ChromBPNet loading, generation) reuse `device`,
# which previously was never defined, so the notebook failed on a fresh kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

arsenal_run_dir = "/mnt/lab_data2/regulatory_lm/scratch/transformer_test/run_20251231_230449" #Substitute path here
args_json = os.path.join(arsenal_run_dir, "args.json") #Model args file
saved_model_file = os.path.join(arsenal_run_dir, "checkpoint_149.pt") #Model checkpoint file

model = load_model(args_json, saved_model_file).to(device)

In [None]:
#Load ChromBPNet Models
def _load_chrombpnet(model_file):
    # Load a ChromBPNet model from its Keras .h5 file via BPNet.from_keras and move it to `device`.
    return BPNet.from_keras(model_file).to(device)

hepg2_model_file = "/oak/stanford/groups/akundaje/projects/chromatin-atlas-2022/DNASE/ENCSR149XIL/chrombpnet_model/chrombpnet_wo_bias.h5" #Substitute path here
h1esc_model_file = "/oak/stanford/groups/akundaje/projects/chromatin-atlas-2022/DNASE/ENCSR000EMU/chrombpnet_model/chrombpnet_wo_bias.h5" #Substitute path here

hepg2_chrombpnet_model = _load_chrombpnet(hepg2_model_file)
h1esc_chrombpnet_model = _load_chrombpnet(h1esc_model_file)


In [None]:
#Define genome and location - we will optimize the 2,114 bp sequence centered on this location
genome = "/mnt/lab_data2/regulatory_lm/oak_backup/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"
genome_data = pyfaidx.Fasta(genome, sequence_always_upper=True)
chrom = "chr4"
seq_len = 2114  # length of the window fed to the supervised models
# Region of interest; the window below is centered on its midpoint.
# These get their own names so `start`/`end` always mean the final window.
peak_start, peak_end = 39469376, 39469725
midpoint = (peak_start + peak_end) // 2
start = midpoint - seq_len // 2
end = start + seq_len  # exactly seq_len bases, even if seq_len is odd
dna_seq = genome_data[chrom][start:end].seq
# Guard against a truncated window (e.g. a region near a chromosome end).
assert len(dna_seq) == seq_len, f"expected {seq_len} bp, got {len(dna_seq)}"


In [None]:
#We now produce one set of 100 generations (seqs_to_gen defaults to 100) - we did 5 such runs for the paper
#The function returns (central 350bp sequence, error) tuples sorted by error; only the central region is edited, but it can easily be combined with the flanking sequence
#For the cell-type specific generations, we used targets of [8.0,1.0] and [3.0,6.0] for [HEPG2, H1ESC] and a temperature of 1.0
#For the HEPG2 counts targeting, we used targets of 5.0, 6.0, 7.0, and 8.0, with a temperature of 0.3
SEED = 0  # generation is stochastic; change the seed for each independent run
torch.manual_seed(SEED)
new_gen_seqs = target_generation(
    model,
    [hepg2_chrombpnet_model, h1esc_chrombpnet_model],  # [HEPG2, H1ESC]
    dna_seq,
    [8.0, 1.0],  # targets, in the same order as the models above
    device,
    absolute_error,
    gen_iters=40,
    iters_per_seq=20,
    mask_ratio=0.01,
    temp=1.0,
)
