## Replot first-layer filters, labeling each y-axis with the ground-truth motif it matches according to TomTom

In [None]:
import os, sys
from six.moves import cPickle
import numpy as np
import pandas as pd
import logomaker
import helper
from tfomics import utils, explain
from tensorflow import keras
from keras import backend as K
import tensorflow.compat.v1.keras.backend as K1

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:

# Motif-database IDs for each ground-truth transcription factor. A factor's
# list holds every ID that is accepted as that factor when TomTom hits are
# matched against ground truth (see helper.match_hits_to_ground_truth).
arid3 = ['MA0151.1', 'MA0601.1', 'PB0001.1']
cebpb = ['MA0466.1', 'MA0466.2']
fosl1 = ['MA0477.1']
gabpa = ['MA0062.1', 'MA0062.2']
mafk = ['MA0496.1', 'MA0496.2']
max1 = ['MA0058.1', 'MA0058.2', 'MA0058.3']
mef2a = ['MA0052.1', 'MA0052.2', 'MA0052.3']
nfyb = ['MA0502.1', 'MA0060.1', 'MA0060.2']
sp1 = ['MA0079.1', 'MA0079.2', 'MA0079.3']
srf = ['MA0083.1', 'MA0083.2', 'MA0083.3']
stat1 = ['MA0137.1', 'MA0137.2', 'MA0137.3', 'MA0660.1', 'MA0773.1']
yy1 = ['MA0095.1', 'MA0095.2']

# Defined but not part of the `motifs` / `motifnames` lists below.
Gmeb1 = ['MA0615.1']

# Each entry pairs a display label with its ID list. zip() splits the pairs
# into two parallel lists, so motifnames[i] always labels motifs[i].
# Entry 0 is a placeholder with an empty label; presumably it is the index
# reported for filters with no ground-truth match (confirm against helper).
motifnames, motifs = (
    list(column)
    for column in zip(
        ('', ['']),
        ('Arid3', arid3),
        ('CEBPB', cebpb),
        ('FOSL1', fosl1),
        ('Gabpa', gabpa),
        ('MAFK', mafk),
        ('MAX', max1),
        ('MEF2A', mef2a),
        ('NFYB', nfyb),
        ('SP1', sp1),
        ('SRF', srf),
        ('STAT1', stat1),
        ('YY1', yy1),
    )
)


In [None]:
# Load the synthetic dataset and split it into named arrays. The unpacking
# requires helper.load_data to return exactly six items. The split order
# (train, validation, test; inputs before labels) is inferred from the
# variable names -- confirm against helper.
data_path = '../data/synthetic_code_dataset.h5'
data = helper.load_data(data_path)
(x_train, y_train,
 x_valid, y_valid,
 x_test, y_test) = data

In [None]:
num_trials = 10
model_names = ['cnn-deep', 'cnn-2', 'cnn-50']
activations = ['relu', 'exponential', 'sigmoid', 'tanh', 'softplus', 'linear', 'elu']
# Only this subset of `activations` is re-plotted with motif labels.
plot_activations = ['relu', 'exponential']

results_path = utils.make_directory('../results', 'task1')
params_path = utils.make_directory(results_path, 'model_params')

# Index into model.layers whose output is read as the filter activations.
# Which layer this is depends on helper.load_model's architecture -- confirm.
layer = 3
threshold = 0.5     # passed to explain.activation_pwm
window = 20         # passed to explain.activation_pwm
num_cols = 8        # subplot columns per figure
figsize = (30,10)
size=32             # passed to helper.match_hits_to_ground_truth

save_path = os.path.join(results_path, 'conv_filters')


def _logo_dataframe(w):
    """Return one filter's information-weighted sequence logo as a DataFrame.

    Every row of ``w`` is scaled by its information content,
    ``log2(4) + sum(w * log2(w + 1e-7))``: 2 bits minus the row's Shannon
    entropy. The 1e-7 offset keeps log2 finite for zero probabilities.
    The result has columns A, C, G, T and an integer position index, which
    is the matrix layout passed to ``logomaker.Logo``.

    Assumes ``w`` is an (L, 4) array of per-position nucleotide
    probabilities in ACGT order, as produced by explain.activation_pwm
    (confirm).
    """
    info = np.log2(4) + np.sum(w * np.log2(w + 1e-7), axis=1, keepdims=True)
    # One vectorized construction instead of filling the frame cell by cell.
    return pd.DataFrame(info * w, columns=list('ACGT'), dtype=float)


def _plot_filter(ax, w, label):
    """Draw filter ``w`` as a logo on ``ax``, with no ticks, no top/right spines,
    and ``label`` as the y-axis label."""
    logomaker.Logo(_logo_dataframe(w), ax=ax)
    # Operate on the explicit axes instead of pyplot's implicit "current" axes.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('none')
    ax.xaxis.set_ticks_position('none')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_ylabel(label, fontsize=16)


for model_name in model_names:
    for activation in plot_activations:
        for trial in range(num_trials):

            # Reset Keras global state left over from the previous model.
            # Building many models in one session otherwise keeps growing
            # memory use.
            keras.backend.clear_session()

            # load model
            model = helper.load_model(model_name,
                                      activation=activation,
                                      input_shape=200)
            name = model_name+'_'+activation+'_'+str(trial)
            print(name)

            weights_path = os.path.join(params_path, name+'.hdf5')
            model.load_weights(weights_path)

            # TomTom results for this model's filters. best_match[n] is used as
            # an index into motifnames to label filter n (presumably 0 when no
            # ground-truth motif matched -- confirm against helper).
            file_path = os.path.join(save_path, name, 'tomtom.tsv')
            best_qvalues, best_match, min_qvalue, match_fraction, match_any = \
                helper.match_hits_to_ground_truth(file_path, motifs, size)

            # Feature maps of the chosen layer on the test set -> one PWM per filter.
            intermediate = keras.Model(inputs=model.inputs, outputs=model.layers[layer].output)
            fmap = intermediate.predict(x_test)
            W = explain.activation_pwm(fmap, x_test, threshold=threshold, window=window)

            num_filters = len(W)
            num_widths = int(np.ceil(num_filters/num_cols))  # number of subplot rows

            fig = plt.figure(figsize=figsize)
            fig.subplots_adjust(hspace=0.3, wspace=0.3)
            for n, w in enumerate(W):
                ax = fig.add_subplot(num_widths, num_cols, n+1)
                _plot_filter(ax, w, motifnames[int(best_match[n])])

            outfile = os.path.join(save_path, 'label_'+name+'.pdf')
            fig.savefig(outfile, format='pdf', dpi=200, bbox_inches='tight')
            # Close exactly this figure, not whatever pyplot considers current.
            plt.close(fig)