In [None]:
# ChromBPNet must be installed. Point the CHROMBPNET_PATH environment variable at
# the repository root (the directory that contains the inner `chrombpnet/`
# package, which is appended to sys.path in the imports cell), or edit the default.
import os
chrombpnet_path = os.environ.get("CHROMBPNET_PATH", "/users/patelas/lib/chrombpnet/")

In [None]:
import deepdish as dd
import json
import numpy as np
import tensorflow as tf
import pandas as pd
import shap
from tensorflow import keras
import pyfaidx
import shutil
import pickle as pkl
import pyBigWig
import errno
from scipy.spatial.distance import jensenshannon
import os
import sys
sys.path.append(f"{chrombpnet_path}/chrombpnet/")
from training.utils.losses import multinomial_nll
from training.utils.one_hot import dna_to_one_hot
from evaluation.interpret.shap_utils import *
sys.path.append("../src/regulatory_lm/")
from utils.viz_sequence import *
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

In [None]:
# Switch TensorFlow to TF1-style graph mode. TensorFlow only allows this before any
# graphs/tensors are created, so keep this cell above the model-loading cell.
# NOTE(review): presumably required by shap's TFDeepExplainer used further down — confirm.
tf.compat.v1.disable_eager_execution()

This notebook demonstrates how to interpret ARSENAL-generated sequences with [ChromBPNet](https://github.com/kundajelab/chrombpnet/tree/master). Run it in the environment specified by ChromBPNet, not in the one provided by ARSENAL.

In [None]:
#Here, we define the ChromBPNet h5 model file and reference genome, along with the sequence that formed the base of our generations
#Change the paths according to your system
model_file = "/oak/stanford/groups/akundaje/projects/chromatin-atlas-2022/DNASE/ENCSR149XIL/chrombpnet_model/chrombpnet_wo_bias.h5"
genome = "/mnt/lab_data2/regulatory_lm/oak_backup/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"
genome_data = pyfaidx.Fasta(genome, sequence_always_upper=True)
chrom = "chr4"
seq_len = 2114
start = 39469376
end = 39469725
midpoint = (start + end) // 2
start = midpoint - seq_len // 2
end = midpoint + seq_len // 2
dna_seq = genome_data[chrom][start:end].seq


In [None]:
# new_middle_seq is the ARSENAL-generated sequence we want to insert in the middle
# of the reference window. The insertion bounds are derived from its length, so
# nothing needs editing when the length changes (for a 350 bp insert in a 2114 bp
# window they come out as 882 and 1232).
new_middle_seq = "CAGGGCACCAGAGTGAGTCCCTGCCAAATCAGATCACGCGGTTTTAGGCTGGAATTCTGCTACCCAGGCCAAGAGACTATCAAAACCTCACAGCATTCCTCAGGCACTTCAGGGTGGGACCACCAGGTGGTGCTATCATTGCACCATTCACCTAAAACTTTGGCTGTGTTCCAGGAAGCGACCGAGAAGTTGCTGCTCTGCTGAAATGTATTTGCATGTACAGTTTATTGTCAGCAGAGGTCACTGTTGTACTAGTATAAATTCCCGAGGGAAAAAATTGCAAGACTATTCACTTCTTTCTTCTTTCTTTTTTTTTTTTGGCCCCTCTTTTCTTTTTATTTTTTTAATCT"
assert len(new_middle_seq) <= seq_len, "insert is longer than the model input window"
center_start = seq_len // 2 - len(new_middle_seq) // 2
center_end = center_start + len(new_middle_seq)

# The replaced span has the same length as the insert, so the edited sequence
# keeps the model's input length (and re-running this cell is idempotent).
dna_seq = dna_seq[:center_start] + new_middle_seq + dna_seq[center_end:]
assert len(dna_seq) == seq_len

In [None]:
def generate_shap_dict(seqs, scores):
    """Bundle one-hot sequences and their SHAP scores into a nested dict.

    The MODISCO workflow expects one-hot sequences shaped (None, 4, inputlen),
    so every array is transposed from (N, L, 4) to (N, 4, L).

    Args:
        seqs: one-hot sequences, rank 3 with a last axis of size 4.
        scores: SHAP scores with exactly the same shape as ``seqs``.

    Returns:
        dict with keys 'raw', 'shap' and 'projected_shap', each mapping to
        {'seq': transposed array}; 'projected_shap' is ``seqs * scores``.
    """
    print(seqs.shape, scores.shape)
    assert seqs.shape == scores.shape
    assert seqs.shape[2] == 4

    channels_first = (0, 2, 1)
    sources = {
        'raw': seqs,
        'shap': scores,
        'projected_shap': seqs * scores,
    }
    return {key: {'seq': np.transpose(arr, channels_first)} for key, arr in sources.items()}

def softmax(x, temp=1):
    """Numerically stable softmax along axis 1.

    Args:
        x: array-like of logits (at least 2-D); normalisation is over axis 1.
        temp: factor the logits are multiplied by before exponentiating
            (so larger values sharpen the distribution).

    Returns:
        ndarray with the same shape as ``x`` whose entries are non-negative
        and sum to 1 along axis 1.
    """
    # np.asarray: with a plain list, `temp * x` would repeat the list instead of scaling it.
    scaled = temp * np.asarray(x)
    # Softmax is shift-invariant. Shifting by the row max makes the largest exponent
    # exp(0) = 1, so nothing can overflow. Shifting by the mean (as before) overflowed
    # to inf/inf = nan once the scaled logits spanned more than ~1400.
    shifted = scaled - np.max(scaled, axis=1, keepdims=True)
    exp_vals = np.exp(shifted)
    return exp_vals / np.sum(exp_vals, axis=1, keepdims=True)


In [None]:
def plot_pred_out(preds):
    """Plot the predicted per-position signal for the first sequence in a batch.

    Args:
        preds: the two model outputs, (profile_logits, counts_head). The profile
            logits are turned into a distribution with ``softmax`` and scaled by
            ``exp(counts_head[0])``. The counts head is presumably log-scale,
            hence the exponential.

    Draws onto the current matplotlib axes; returns nothing.
    """
    profile_logits, counts_head = preds
    profile = softmax(profile_logits)
    total = np.exp(counts_head[0])
    signal = total * profile
    plt.plot(signal[0])


In [None]:
# Load the trained ChromBPNet model. Its custom loss (and `tf`) must be
# registered so Keras can deserialise the saved h5 file.
custom_objects = {'multinomial_nll': multinomial_nll, 'tf': tf}
with keras.utils.CustomObjectScope(custom_objects):
    model = keras.models.load_model(model_file)

# One-hot encode the edited sequence (dna_to_one_hot takes a list of sequences).
one_hot_seqs = dna_to_one_hot([dna_seq])

# Length of the first (profile) output head.
outlen = model.output_shape[0][1]

# The profile and counts explainers share the same model input and encoded batch.
profile_model_input, counts_model_input = model.input, model.input
profile_input, counts_input = one_hot_seqs, one_hot_seqs

# Forward pass; the result is unpacked later as (profile logits, counts head).
model_out = model.predict(one_hot_seqs)

In [None]:
# Plot the predicted per-position signal: the softmax profile scaled by the
# exponentiated counts head, as computed in plot_pred_out.
plt.figure(dpi=300, figsize=(16,4))
plot_pred_out(model_out)
plt.xlabel("Position in output window")
plt.ylabel("Predicted signal")
plt.title("ChromBPNet Predicted Profile", fontsize=20)
plt.show()

In [None]:
# Show the predicted counts so we can verify they match what we want. The counts
# head is exponentiated before plotting (see plot_pred_out), so it is presumably
# log-scale; display both the raw head and its exponential.
logits, counts = model_out
{"counts_head (log scale)": counts, "exp(counts_head)": np.exp(counts)}

In [None]:
# Compute importance scores for the counts head with shap's TFDeepExplainer.
# The counts output is summed over its last axis so the explainer attributes a
# reduced scalar target (presumably one value per sequence; confirm the head's shape).
counts_target = tf.reduce_sum(model.outputs[1], axis=-1)
profile_model_counts_explainer = shap.explainers.deep.TFDeepExplainer(
    (counts_model_input, counts_target),
    shuffle_several_times,  # background/reference generator from shap_utils
    combine_mult_and_diffref=combine_mult_and_diffref,
)
counts_shap_scores = profile_model_counts_explainer.shap_values(
    counts_input,
    progress_message=100,
)
counts_scores_dict = generate_shap_dict(one_hot_seqs, counts_shap_scores)

In [None]:
# Plot projected (one-hot * SHAP) importance scores over the inserted region.
# generate_shap_dict stores arrays as (N, 4, L), so [0][:, a:b] is the 4 x window
# slice of the first sequence.
projected_shap = counts_scores_dict["projected_shap"]["seq"]
plt.figure(dpi=300)
# NOTE(review): max_score/min_score are not used in this cell; kept in case
# another cell relies on them.
max_score = projected_shap.max()
min_score = projected_shap.min()
# Use the insertion bounds rather than a hard-coded +/-175 around the centre, so
# the window follows the length of new_middle_seq.
plot_weights(projected_shap[0][:, center_start:center_end], subticks_frequency=100)
plt.title("ChromBPNet Importance Scores", fontsize=20)
plt.show()
