## Compare attribution-map sequence logos for Basset models with relu and exponential activations, trained on the Basset DNase-seq dataset

In [None]:
import os, h5py
import numpy as np
import logomaker
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import helper
from model_zoo import basset
from tfomics import utils, explain, metrics


In [None]:
# Locations of trained weights and of the figures written by this notebook.
# `os` is imported directly by this notebook, so use it rather than reaching
# through `tfomics.utils` (which only works if that module happens to import os).
results_path = os.path.join('../results', 'task4')
params_path = os.path.join(results_path, 'model_params')   # model weight files loaded below
save_path = os.path.join(results_path, 'conv_filters')
# presumably creates <results_path>/attr_plots if missing and returns its path -- verify in tfomics.utils
plot_path = utils.make_directory(results_path, 'attr_plots')

In [None]:
# Load the test split from the HDF5 dataset.
data_path = '../data/er.h5'
# The context manager closes the file handle once the arrays have been copied
# into memory; np.array(...) materialises each dataset before the file closes.
with h5py.File(data_path, 'r') as trainmat:
    x_test = np.array(trainmat['test_in']).astype(np.float32)
    y_test = np.array(trainmat['test_out']).astype(np.int32)
    labels = np.array(trainmat['target_labels']).astype(str)
    test_headers = np.array(trainmat['test_headers']).astype(str)

# Drop singleton axes, then swap the last two axes.
# NOTE(review): presumably this yields (N, L, A) = (sequences, positions, nucleotides),
# matching how later cells unpack `scores.shape` -- confirm against the stored layout.
x_test = np.squeeze(x_test)
x_test = x_test.transpose([0,2,1])

#### List experiments with a large number of positive test sequences (more than 10,000)

In [None]:
# Sum each label column over the test set and keep the columns whose total
# exceeds 10000.
label_totals = np.sum(y_test, axis=0)
indices = np.where(label_totals > 10000)[0]

# Print "<label column index> <label name>" for every retained label.
for label_idx, name in zip(indices, labels[indices]):
    print(label_idx, name)

#### Set experiment

In [None]:
# Label column whose sequences are examined below.
class_index = 4

# A test sequence is selected only if all three conditions hold:
#   * it is labelled 1 for `class_index`,
#   * none of its entries equals 0.25
#     (presumably the encoding of ambiguous bases -- TODO confirm),
#   * its labels sum to exactly 1 across all columns.
is_target_class = y_test[:, class_index] == 1
has_no_quarter_entries = np.sum(np.sum(x_test == .25, axis=2), axis=1) == 0
label_sum_is_one = np.sum(y_test, axis=1) == 1

index = np.where(is_target_class & has_no_quarter_entries & label_sum_is_one)[0]
print(len(index))

#### Load Basset model with relu activations

In [None]:
# Sequences to attribute. `X` was previously used here without ever being
# assigned (its first assignment is in the "same region as Basset paper" section
# further down), so running the notebook top-to-bottom raised NameError.
# The plotting cells label row i with index[i], so X must be x_test[index[...]]
# in the same order.
# NOTE(review): the cap of 10 sequences is an assumption chosen so that the
# stacked logo figure stays readable -- confirm the intended number.
num_plots = 10
X = x_test[index[:num_plots]]

# Build the Basset architecture with relu activations and load its trained weights.
model_name = 'basset'
activation = 'relu'

name = model_name+'_'+activation
model = basset.model(activation)

weights_path = os.path.join(params_path, name+'.hdf5')
model.load_weights(weights_path)

# Saliency for `class_index`, then multiplied element-wise by the input in place,
# which zeroes the attribution wherever X is zero.
# NOTE(review): the meaning of layer=-1 is defined by tfomics.explain.saliency -- verify there.
scores = explain.saliency(model, X, class_index=class_index, layer=-1)
scores *= X

#### Load Basset model with exponential activations

In [None]:
# Build the Basset architecture with exponential activations and load the
# weights stored under the "<model>_<activation>_balance" name.
model_name = 'basset'
activation = 'exponential'
name = '_'.join([model_name, activation, 'balance'])

model = basset.model(activation)
weights_path = os.path.join(params_path, name + '.hdf5')
model.load_weights(weights_path)

# Saliency for `class_index`, multiplied element-wise by the input in place.
scores2 = explain.saliency(model, X, class_index=class_index, layer=-1)
scores2 *= X

#### Plot saliency maps for relu activations

In [None]:
# Draw one sequence logo per attributed sequence (relu model), stacked
# vertically, restricted to positions 150..449, and save the figure as a PDF.
plot_range = range(150,450)
positions = list(plot_range)
num_seqs = len(scores)

fig = plt.figure(figsize=(25,20))
for i in range(num_seqs):

    # Take the plotted window in one slice instead of filling the frame cell by
    # cell with .iloc; the frame has one row per position and one column per base.
    window = np.asarray(scores[i], dtype=float)[positions]
    counts_df = pd.DataFrame(window, columns=list('ACGT'))

    ax = plt.subplot(num_seqs, 1, i+1)
    logomaker.Logo(counts_df, ax=ax)
    # Hide ticks on both axes; only the row label is shown.
    ax.yaxis.set_ticks_position('none')
    ax.xaxis.set_ticks_position('none')
    ax.set_xticks([])
    ax.set_yticks([])
    # Label each row with the sequence's position in the test set.
    ax.set_ylabel(index[i], fontsize=18)

outfile = os.path.join(plot_path, 'basset_relu.pdf')
fig.savefig(outfile, format='pdf', dpi=200, bbox_inches='tight')


#### Plot saliency maps for exponential activations

In [None]:
# Draw one sequence logo per attributed sequence (exponential model), stacked
# vertically, restricted to positions 150..449, and save the figure as a PDF.
plot_range = range(150,450)
positions = list(plot_range)
# Size the figure and the loop from scores2, the array actually plotted here
# (this cell previously took both from the relu `scores` array).
num_seqs = len(scores2)

fig = plt.figure(figsize=(25,20))
for i in range(num_seqs):

    # Take the plotted window in one slice instead of filling the frame cell by
    # cell with .iloc; the frame has one row per position and one column per base.
    window = np.asarray(scores2[i], dtype=float)[positions]
    counts_df = pd.DataFrame(window, columns=list('ACGT'))

    ax = plt.subplot(num_seqs, 1, i+1)
    logomaker.Logo(counts_df, ax=ax)
    # Hide ticks on both axes; only the row label is shown.
    ax.yaxis.set_ticks_position('none')
    ax.xaxis.set_ticks_position('none')
    ax.set_xticks([])
    ax.set_yticks([])
    # Label each row with the sequence's position in the test set.
    ax.set_ylabel(index[i], fontsize=18)


outfile = os.path.join(plot_path, 'basset_exp.pdf')
fig.savefig(outfile, format='pdf', dpi=200, bbox_inches='tight')


# Compare saliency map for same region as Basset paper (Fig. 4) 

Region: Chr 9: 118,434,976–118,435,175 in H1-hESCs


In [None]:
# Reference sequence (mixed case) to be one-hot encoded.
test = 'TATAAATATTAGTTGAATGGTATGAAGTAAAACAAActtatactggtaatagctttggaatttacaaagcattttcccatgcattatgtcttctcctcctcatattaaccctgcaaacgaaataacattattacccgtactttacagaagaggacactgaagccaaaggagaaaattaactagctcagtcttgcatgacccctgtgaatggactgatcttgaaacccaggtaaccttactccCTGGTCCCAGCCTTTGTTAATGGGGACACAATCCTGGAAATTTTGCCTGTGTGTAAACCTCTAGGGGCTTTTTCTTTCATCGTTTTACATCAGCCAGACTCTGACTCACAGCTGGAGAATCAGCTTCCTTATTATGTAGCGAATTCCATGAACACACACCAAGAGTTGTTTTCTGTAACAGGCTGAAGTAGCTTCTTCTCCCAGTCTCTTTCTCCCATCAAAATTAGAATATCTTTCCTTGGAAAACTGTGCCCAGGTTGAGGGGGACTTCTCCCTGGTTTTGTGTAGACTCTTTGATATGCTCCAAACTCAACGCCTTTCCTTCAATCCCTGGGGCCTTAGGAACAGCCAACCCACA'

alphabet = 'ACGT'
L = len(test)

# Column of each base within `alphabet`; the sequence is upper-cased first, and
# str.index raises ValueError for any character outside A/C/G/T.
base_columns = np.array([alphabet.index(base) for base in test.upper()], dtype=int)

# One-hot encode into a batch of one sequence: shape (1, L, 4), float64.
X = np.zeros((1,L,4))
X[0, np.arange(L), base_columns] = 1

In [None]:
# Relu Basset model, attributed for label column 113 on the sequence held in X.
class_index = 113

model_name = 'basset'
activation = 'relu'
name = f'{model_name}_{activation}'

model = basset.model(activation)
weights_path = os.path.join(params_path, f'{name}.hdf5')
model.load_weights(weights_path)

# Saliency for `class_index`, multiplied element-wise by the input in place.
scores = explain.saliency(model, X, class_index=class_index, layer=-1)
scores *= X

In [None]:
# Exponential-activation Basset model ("_balance" weights), attributed for the
# current `class_index` on the sequence held in X.
model_name, activation = 'basset', 'exponential'
name = '{}_{}_balance'.format(model_name, activation)

weights_path = os.path.join(params_path, name + '.hdf5')
model = basset.model(activation)
model.load_weights(weights_path)

# Saliency for `class_index`, multiplied element-wise by the input in place.
scores2 = explain.saliency(model, X, class_index=class_index, layer=-1)
scores2 *= X

In [None]:
def _draw_logo_panel(ax, attr_map, positions, label):
    """Draw the attribution logo for one sequence on `ax`.

    attr_map  -- per-position, per-base attribution scores for one sequence
                 (indexed as attr_map[position, base], bases in 'ACGT' order)
    positions -- list of position indices to plot
    label     -- text shown as the panel's y-axis label
    """
    # Slice the plotted window in one step rather than filling the frame
    # element by element with .iloc.
    window = np.asarray(attr_map, dtype=float)[positions]
    counts_df = pd.DataFrame(window, columns=list('ACGT'))

    logomaker.Logo(counts_df, ax=ax)
    # Hide ticks on both axes; only the panel label is shown.
    ax.yaxis.set_ticks_position('none')
    ax.xaxis.set_ticks_position('none')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_ylabel(label, fontsize=16)


# Two-panel comparison over positions 150..429: relu on top, exponential below.
plot_range = range(150,430)
positions = list(plot_range)
outfile = os.path.join(plot_path, 'basset_compare.pdf')

for i in range(len(scores)):

    fig = plt.figure(figsize=(25,4))
    _draw_logo_panel(plt.subplot(2,1,1), scores[i], positions, 'Relu')
    _draw_logo_panel(plt.subplot(2,1,2), scores2[i], positions, 'Exp')

    # NOTE: every iteration writes to the same file, so only the last sequence's
    # figure is kept when `scores` holds more than one sequence.
    fig.savefig(outfile, format='pdf', dpi=200, bbox_inches='tight')
