In [None]:
import torch
import os
import numpy as np
from torch.utils import data
from torch.nn import DataParallel
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel, AutoModelForCausalLM, BertConfig, BertForPreTraining, BertForMaskedLM, AutoConfig, AutoModelForSequenceClassification
import transformers
import matplotlib.pyplot as plt 
import pandas as pd
import seaborn as sns
import scipy
import json
import tqdm
import sys
import pyfaidx
sys.path.append("../src/regulatory_lm/")
from evals.nucleotide_dependency import *
from modeling.model import *
from utils.viz_sequence import *

In [None]:
# MODULES is not defined in this notebook; it comes from one of the star imports above (TODO confirm which).
model_str_dict = MODULES
# Lookup from dtype name string to the corresponding torch dtype.
FLOAT_DTYPES = dict(
    float32=torch.float32,
    float64=torch.float64,
    bfloat16=torch.bfloat16,
    float16=torch.float16,
)


In [None]:
MAPPING = {'A': 0, 'C': 1, 'G': 2, 'T': 3}


def encode_sequence(sequence):
    """Encode a DNA string as integer codes: A=0, C=1, G=2, T=3, anything else=4.

    Matching is case-insensitive, so soft-masked (lowercase) bases encode the
    same as uppercase ones instead of silently becoming the unknown code 4.
    """
    return [MAPPING.get(nucleotide.upper(), 4) for nucleotide in sequence]


def revcomp(seq_list):
    """Reverse-complement a list of integer codes produced by encode_sequence.

    A<->T (0<->3) and C<->G (1<->2). Any other code (e.g. 4 for N) is kept as-is
    rather than being turned into an invalid negative value by ``3 - x``.
    """
    return [3 - x if 0 <= x <= 3 else x for x in seq_list][::-1]


def revcomp_string(dna_sequence):
    """Reverse-complement a DNA string (output is uppercase; N maps to N).

    Raises:
        KeyError: for characters other than A/C/G/T/N (case-insensitive).
    """
    complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N'}
    return ''.join(complement[base] for base in reversed(dna_sequence.upper()))


In [None]:
# Path to the hg38 FASTA; set the HG38_FASTA environment variable to override it
# instead of editing the notebook. #Replace with your path
genome = os.environ.get(
    "HG38_FASTA",
    "/oak/stanford/groups/akundaje/patelas/regulatory_lm/data/hg38_repeat_lowercase.fa",
)
genome_data = pyfaidx.Fasta(genome, sequence_always_upper=True)
#Size does not matter - we analyze a stretch of size seq_len centered around the provided location
chrom = "chr3"
seq_len = 350
start = 4868352
end = 4868665
midpoint = (start + end) // 2
start = midpoint - seq_len // 2
# Derive end from start so the window is exactly seq_len bases even when seq_len is odd.
end = start + seq_len
print(midpoint, start, end)
dna_seq = genome_data[chrom][start:end].seq
# A window running past a chromosome edge can come back short; fail loudly instead of
# silently analyzing fewer bases.
assert len(dna_seq) == seq_len, f"expected {seq_len} bp for {chrom}:{start}-{end}, got {len(dna_seq)}"
seq_tensor = torch.tensor(encode_sequence(dna_seq))


# HyenaDNA

In [None]:
# HyenaDNA checkpoint; trust_remote_code=True lets transformers run the model repo's own code.
model_name = "LongSafari/hyenadna-large-1m-seqlen-hf"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    padding_side="right",
)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

In [None]:
# Token ids of the region for inspection; predict() re-tokenizes its input itself.
seq_tensor = tokenizer(
    [dna_seq],
    return_tensors="pt",
)["input_ids"]

In [None]:
def plot_scores(scores, demarcs=None, limits=None):
    """Draw one score track with ``plot_weights`` on a large, high-DPI figure.

    Args:
        scores: array forwarded unchanged to ``plot_weights``; ``predict`` passes
            a (1, 4, window) slice. TODO confirm that is the layout
            ``plot_weights`` expects.
        demarcs: optional iterable of (start, end) pairs. A vertical black line
            is drawn at both ends of each pair (e.g. to outline motifs).
        limits: optional (ymin, ymax). When given, it sets the final y-range.
            Otherwise, tracks whose largest |score| is below 0.05 are drawn on a
            fixed [-0.05, 0.05] axis.
    """
    # None default instead of a shared mutable [] default.
    demarcs = [] if demarcs is None else demarcs
    plt.figure(dpi=300, figsize=[16, 8])
    plot_weights(scores)
    plt.xticks([])
    if np.abs(scores).max() < 0.05:
        plt.ylim(-0.05, 0.05)
    for motif in demarcs:
        plt.axvline(motif[0], color="black")
        plt.axvline(motif[1], color="black")
    if limits is not None:
        plt.ylim(limits[0], limits[1])
    plt.show()

In [None]:
def predict(seq, mask_inds, model, tokenizer, demarcs=None, limits=None, device=None, seed_token_id=8):
    '''
    Score every position of a sequence with HyenaDNA and plot the result.

    HyenaDNA is autoregressive, so nothing is masked. A single forward pass gives,
    at each position, the model's distribution over the next token. The sequence
    is prefixed with a seed token (id 8, a "C" per the original notes) so the
    first real base also receives a prediction.

    Args:
        seq: DNA string. Characters outside A/C/G/T get an all-zero one-hot column.
        mask_inds: positions to display. Only min(mask_inds)..max(mask_inds) is
            plotted; despite the name, nothing is masked for this model.
        model: causal LM whose output exposes ``.logits``. Side effect: it is
            moved to ``device`` and switched to eval mode.
        tokenizer: matching tokenizer. Assumed to emit one token per base,
            optionally followed by special tokens (TODO confirm).
        demarcs, limits: forwarded to ``plot_scores``.
        device: where to run the model. Defaults to a notebook-level ``device``
            if one is defined, else CUDA when available, else CPU.
        seed_token_id: token prepended to start the autoregressive pass.

    Returns:
        numpy array of shape (1, 4, len(seq)). For the reference base at each
        position it holds p * log(p / p_mean), and 0 for the other three bases.
        Here p is the softmax probability of token ids 7..10 (treated as
        A, C, G, T) and p_mean is that probability averaged over all positions.

    Raises:
        ValueError: if ``mask_inds`` is empty or the model yields fewer
            predictions than there are bases.
    '''
    mask_inds = list(mask_inds)
    if not mask_inds:
        raise ValueError("mask_inds must contain at least one position")
    demarcs = [] if demarcs is None else demarcs
    if device is None:
        # `device` is not defined in this notebook's own cells (presumably it comes
        # from a star import); use it when present, otherwise choose one here.
        device = globals().get("device")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # (1, 4, L) indicator of the reference base; the unknown code 4 matches no row.
    codes = np.asarray(encode_sequence(seq))
    one_hot = (codes[None, None, :] == np.arange(4)[None, :, None]).astype(np.int8)

    token_ids = tokenizer([seq], return_tensors='pt')["input_ids"]
    # Need something to start off with because the model is autoregressive.
    seed = torch.tensor([[seed_token_id]], dtype=token_ids.dtype)
    token_ids = torch.cat([seed, token_ids], dim=1)

    model = model.to(device)
    model.eval()
    with torch.no_grad():
        logits = model(token_ids.to(device)).logits
    # Logit i predicts token i + 1, so the first len(seq) logits (seed + bases 0..L-2)
    # predict bases 0..L-1. Anything after that (e.g. a trailing EOS) is dropped.
    probs = torch.softmax(logits, dim=-1)[:, :len(seq), 7:11]
    if probs.shape[1] != len(seq):
        raise ValueError(
            f"expected {len(seq)} predictions, got {probs.shape[1]}; check the tokenizer output"
        )
    background = probs.mean(dim=1, keepdim=True)  # per-base mean over positions
    scores = (probs * torch.log(probs / background)).permute(0, 2, 1).cpu().numpy()

    window = slice(min(mask_inds), max(mask_inds) + 1)
    plot_scores(scores[:, :, window] * one_hot[:, :, window], demarcs, limits)
    plot_scores(scores[:, :, window], demarcs, limits)
    return scores * one_hot

In [None]:
# Score and plot the whole extracted window (len(dna_seq) == seq_len) on a fixed y-range.
probs_norm = predict(dna_seq, list(range(len(dna_seq))), model, tokenizer, limits=[0, 1.5])

# Caduceus

In [None]:
# Caduceus checkpoint; trust_remote_code=True lets transformers run the model repo's own code.
model_name = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    padding_side="right",
)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)


In [None]:
# Token ids of the region for inspection; predict() re-tokenizes its input itself.
seq_tensor = tokenizer(
    [dna_seq],
    return_tensors="pt",
)["input_ids"]

In [None]:
def plot_scores(scores, demarcs=None, limits=None):
    """Draw one score track with ``plot_weights`` on a large, high-DPI figure.

    Args:
        scores: array forwarded unchanged to ``plot_weights``; ``predict`` passes
            a (1, 4, window) slice. TODO confirm that is the layout
            ``plot_weights`` expects.
        demarcs: optional iterable of (start, end) pairs. A vertical black line
            is drawn at both ends of each pair (e.g. to outline motifs).
        limits: optional (ymin, ymax). When given, it sets the final y-range.
            Otherwise, tracks whose largest |score| is below 0.05 are drawn on a
            fixed [-0.05, 0.05] axis.
    """
    # None default instead of a shared mutable [] default.
    demarcs = [] if demarcs is None else demarcs
    plt.figure(dpi=300, figsize=[16, 8])
    plot_weights(scores)
    plt.xticks([])
    if np.abs(scores).max() < 0.05:
        plt.ylim(-0.05, 0.05)
    for motif in demarcs:
        plt.axvline(motif[0], color="black")
        plt.axvline(motif[1], color="black")
    if limits is not None:
        plt.ylim(limits[0], limits[1])
    plt.show()

In [None]:
def predict(seq, mask_inds, model, tokenizer, demarcs=None, limits=None, device=None, mask_token_id=3):
    '''
    Score positions of a sequence with Caduceus by masking them one at a time.

    For each index in ``mask_inds``, the token at that index is replaced by
    ``mask_token_id``, the model is run, and only its prediction at that index
    is kept. Each listed position is therefore scored without seeing its own base.

    Args:
        seq: DNA string. Characters outside A/C/G/T get an all-zero one-hot column.
        mask_inds: positions to mask and score. The plot shows
            min(mask_inds)..max(mask_inds).
        model: masked LM whose output exposes ``.logits``. Side effect: it is
            moved to ``device`` and switched to eval mode.
        tokenizer: matching tokenizer. Assumed to emit one token per base with no
            leading special token (token index == base index), optionally
            followed by special tokens (TODO confirm).
        demarcs, limits: forwarded to ``plot_scores``.
        device: where to run the model. Defaults to a notebook-level ``device``
            if one is defined, else CUDA when available, else CPU.
        mask_token_id: id written over the masked position. 3 is assumed to be
            the tokenizer's [MASK] id (TODO confirm via tokenizer.mask_token_id).

    Returns:
        numpy array of shape (1, 4, len(seq)). For the reference base at each
        position it holds p * log(p / p_mean), and 0 for the other three bases.
        Here p is the softmax probability of token ids 7..10 (treated as
        A, C, G, T) and p_mean is that probability averaged over all positions.
        NOTE(review): positions not in ``mask_inds`` keep the predictions from
        the last masked pass (the sequence masked at ``mask_inds[-1]``).

    Raises:
        ValueError: if ``mask_inds`` is empty or the model yields fewer
            predictions than there are bases.
    '''
    mask_inds = list(mask_inds)
    if not mask_inds:
        raise ValueError("mask_inds must contain at least one position")
    demarcs = [] if demarcs is None else demarcs
    if device is None:
        # `device` is not defined in this notebook's own cells (presumably it comes
        # from a star import); use it when present, otherwise choose one here.
        device = globals().get("device")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # (1, 4, L) indicator of the reference base; the unknown code 4 matches no row.
    codes = np.asarray(encode_sequence(seq))
    one_hot = (codes[None, None, :] == np.arange(4)[None, :, None]).astype(np.int8)

    # Tokenize and place the model once; each pass masks a fresh copy of the ids.
    token_ids = tokenizer([seq], return_tensors='pt')["input_ids"]
    model.eval()
    model = model.to(device)
    masked_columns = []
    with torch.no_grad():
        for ind in mask_inds:
            masked_ids = token_ids.clone()
            masked_ids[:, ind] = mask_token_id
            logits = model(masked_ids.to(device)).logits
            # One prediction per base (drops e.g. a trailing EOS) -> (1, 4, L).
            probs = torch.softmax(logits, dim=-1)[:, :len(seq), 7:11].cpu().permute(0, 2, 1)
            masked_columns.append(probs[:, :, ind])
    if probs.shape[-1] != len(seq):
        raise ValueError(
            f"expected {len(seq)} predictions, got {probs.shape[-1]}; check the tokenizer output"
        )
    # Splice each masked-pass column into the last pass's tensor.
    for ind, column in zip(mask_inds, masked_columns):
        probs[:, :, ind] = column
    probs = probs.permute(0, 2, 1)  # back to (1, L, 4)
    background = probs.mean(dim=1, keepdim=True)  # per-base mean over positions
    scores = (probs * torch.log(probs / background)).permute(0, 2, 1).cpu().numpy()

    window = slice(min(mask_inds), max(mask_inds) + 1)
    plot_scores(scores[:, :, window] * one_hot[:, :, window], demarcs, limits)
    plot_scores(scores[:, :, window], demarcs, limits)
    return scores * one_hot

In [None]:
# Mask, score and plot every position of the extracted window (len(dna_seq) == seq_len) on a fixed y-range.
probs_norm = predict(dna_seq, list(range(len(dna_seq))), model, tokenizer, limits=[0, 1.5])