In [1]:
import numpy as np
import scipy

import pandas as pd

import random, os, h5py, math, time, glob

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

import sklearn
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.preprocessing import OneHotEncoder

import keras
import keras.backend as K
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute
from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply
from keras import Model
import keras.optimizers
from keras.models import Sequential, Model, load_model

import tensorflow as tf

#tf.disable_v2_behavior()

from mpradragonn_predictor_pytorch import *

class IdentityEncoder :
    """One-hot encoder/decoder between sequence strings and (seq_len, n_channels) numpy arrays.

    Symbols that are missing from the channel map (e.g. 'N' for an ACGT map) are
    encoded as all-zero rows.
    """
    
    def __init__(self, seq_len, channel_map) :
        """
        seq_len     : number of rows of every array produced by encode().
        channel_map : dict mapping each symbol to its channel (column) index.
        """
        self.seq_len = seq_len
        self.n_channels = len(channel_map)
        self.encode_map = channel_map
        #Inverse lookup: channel index -> symbol
        self.decode_map = {
            channel_ix: nt for nt, channel_ix in self.encode_map.items()
        }
    
    def encode(self, seq) :
        """Return a new (seq_len, n_channels) array with the one-hot encoding of seq.

        Rows beyond len(seq) stay zero. A seq longer than seq_len raises IndexError.
        """
        encoding = np.zeros((self.seq_len, self.n_channels))
        self.encode_inplace(seq, encoding)

        return encoding
    
    def encode_inplace(self, seq, encoding) :
        """Set encoding[i, channel] = 1 for every position i of seq whose symbol is known.

        Only writes ones; existing entries of encoding are never cleared.
        """
        for i, nt in enumerate(seq) :
            if nt in self.encode_map :
                encoding[i, self.encode_map[nt]] = 1.
    
    def encode_inplace_sparse(self, seq, encoding_mat, row_index) :
        """Sparse encoding is not supported by this encoder."""
        raise NotImplementedError()
    
    def decode(self, encoding) :
        """Return the string obtained by taking the argmax channel of every row.

        Note that an all-zero row decodes to the symbol of channel 0, because
        np.argmax returns index 0 on ties.
        """
        return ''.join(
            self.decode_map[np.argmax(encoding[pos, :])]
            for pos in range(0, encoding.shape[0])
        )
    
    def decode_sparse(self, encoding_mat, row_index) :
        """Sparse decoding is not supported by this encoder."""
        raise NotImplementedError()


Using TensorFlow backend.


In [2]:
#Load pytorch MPRA-DragoNN model skeleton
# DragoNNClassifier is not defined in this notebook; presumably it is provided by the
# star import of mpradragonn_predictor_pytorch (TODO confirm).
# NOTE(review): the cell output reports an existing checkpoint, so the constructor appears
# to look up saved checkpoints for this run_name; verify against the module.
analyzer = DragoNNClassifier(run_name='mpradragonn_pytorch', seq_len=145)


[*] Checkpoint 10 found!


In [3]:
#Load MPRA-DragoNN Keras predictor model

#Specify file path to pre-trained predictor network

def load_data(data_name, valid_set_size=0.05, test_set_size=0.05) :
    """Load cached train/test inputs from a pickle file and move their last axis to position 1.

    data_name      : path to a pickle file holding a dict with 'x_train' and 'x_test' arrays
                     (each must have at least 4 axes, since axis 3 is moved).
    valid_set_size : unused; kept so existing callers keep working.
    test_set_size  : unused; kept so existing callers keep working.

    Returns (x_train, x_test) with axis 3 moved to axis 1.
    """
    #Imported locally because the notebook's import cell does not import pickle
    import pickle
    
    #Load cached dataframe
    #NOTE(security): pickle.load can execute arbitrary code; only load trusted cache files.
    with open(data_name, 'rb') as cache_file :
        cached_dict = pickle.load(cache_file)
    
    x_train = np.moveaxis(cached_dict['x_train'], 3, 1)
    x_test = np.moveaxis(cached_dict['x_test'], 3, 1)
    
    return x_train, x_test

def load_predictor_model(model_path) :
    """Rebuild the MPRA-DragoNN Keras network layer by layer and load its weights.

    The network is a Sequential stack of nine Conv1D -> BatchNormalization -> Dropout(0.1)
    blocks (input shape (145, 4)), interleaved with three max-pooling layers, followed by
    Flatten, a 100-unit relu Dense block and a 12-unit linear Dense output. It is compiled
    with mean squared error and SGD(lr=0.1) before the weights at model_path are loaded.

    model_path : path of the weight file passed to load_weights().

    Returns the compiled Sequential model with weights loaded.
    """

    #One entry per sublayer: ([(n_filters, kernel_width), ...], width of the trailing max-pool)
    sublayers = [
        ([(48, 3), (64, 3), (100, 3), (150, 7), (300, 7)], 3),  # sublayer 1
        ([(200, 7), (200, 3), (200, 3)], 4),                    # sublayer 2
        ([(200, 7)], 4),                                        # sublayer 3
    ]

    net = Sequential()

    block_ix = 0
    for conv_specs, pool_width in sublayers :
        for n_filters, kernel_width in conv_specs :
            block_ix += 1
            suffix = str(block_ix) + '_copy'

            conv_kwargs = {'padding': 'same', 'activation': 'relu', 'name': 'dragonn_conv1d_' + suffix}
            if block_ix == 1 :
                #Only the very first layer declares the input shape
                conv_kwargs['input_shape'] = (145, 4)

            net.add(Conv1D(n_filters, kernel_width, **conv_kwargs))
            net.add(BatchNormalization(name='dragonn_batchnorm_' + suffix))
            net.add(Dropout(0.1, name='dragonn_dropout_' + suffix))

        net.add(MaxPooling1D(pool_width))

    #Dense head
    net.add(Flatten())
    net.add(Dense(100, activation='relu', name='dragonn_dense_1_copy'))
    net.add(BatchNormalization(name='dragonn_batchnorm_10_copy'))
    net.add(Dropout(0.1, name='dragonn_dropout_10_copy'))
    net.add(Dense(12, activation='linear', name='dragonn_dense_2_copy'))

    net.compile(
        loss= "mean_squared_error",
        optimizer=keras.optimizers.SGD(lr=0.1)
    )

    net.load_weights(model_path)
    
    return net


#Specify file path to pre-trained predictor network

saved_predictor_model_path = '../seqprop/examples/mpradragonn/pretrained_deep_factorized_model.hdf5'

#Rebuild the Keras architecture and load the pre-trained weights into it
saved_predictor = load_predictor_model(saved_predictor_model_path)

#One-hot encoder for 145-long sequences over the A/C/G/T alphabet
acgt_encoder = IdentityEncoder(145, {'A':0, 'C':1, 'G':2, 'T':3})

#Get latent space predictor
# Second model sharing the same layers, returning the regular outputs plus one intermediate
# activation for debugging. Despite the "_w_dense" name, the extra output is the output of
# the first dropout layer ('dragonn_dropout_1_copy'), not of a dense layer.
saved_predictor_w_dense = Model(
    inputs = saved_predictor.inputs,
    outputs = saved_predictor.outputs + [saved_predictor.get_layer('dragonn_dropout_1_copy').output]
)
saved_predictor_w_dense.compile(loss='mse', optimizer=keras.optimizers.SGD(lr=0.1))



In [4]:
#Print the layer-by-layer summary (output shapes and parameter counts) of the Keras model
saved_predictor.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dragonn_conv1d_1_copy (Conv1 (None, 145, 48)           624       
_________________________________________________________________
dragonn_batchnorm_1_copy (Ba (None, 145, 48)           192       
_________________________________________________________________
dragonn_dropout_1_copy (Drop (None, 145, 48)           0         
_________________________________________________________________
dragonn_conv1d_2_copy (Conv1 (None, 145, 64)           9280      
_________________________________________________________________
dragonn_batchnorm_2_copy (Ba (None, 145, 64)           256       
_________________________________________________________________
dragonn_dropout_2_copy (Drop (None, 145, 64)           0         
_________________________________________________________________
dragonn_conv1d_3_copy (Conv1 (None, 145, 100)         

In [5]:
#Collect weights from keras model

def _get_conv_block_weights(block_ix) :
    """Return (kernel, bias, gamma, beta, moving_mean, moving_var) of conv block block_ix.

    Reads the 'dragonn_conv1d_<ix>_copy' and 'dragonn_batchnorm_<ix>_copy' layers of
    saved_predictor. The conv kernel gets a singleton axis inserted at position 1.
    """
    suffix = str(block_ix) + '_copy'
    kernel, bias = saved_predictor.get_layer('dragonn_conv1d_' + suffix).get_weights()
    kernel = np.expand_dims(kernel, axis=1)
    gamma, beta, moving_mean, moving_var = saved_predictor.get_layer('dragonn_batchnorm_' + suffix).get_weights()
    return kernel, bias, gamma, beta, moving_mean, moving_var

#Sublayer 1 (conv blocks 1-5)
conv_1_weight, conv_1_bias, gamma_1, beta_1, moving_mean_1, moving_var_1 = _get_conv_block_weights(1)
conv_2_weight, conv_2_bias, gamma_2, beta_2, moving_mean_2, moving_var_2 = _get_conv_block_weights(2)
conv_3_weight, conv_3_bias, gamma_3, beta_3, moving_mean_3, moving_var_3 = _get_conv_block_weights(3)
conv_4_weight, conv_4_bias, gamma_4, beta_4, moving_mean_4, moving_var_4 = _get_conv_block_weights(4)
conv_5_weight, conv_5_bias, gamma_5, beta_5, moving_mean_5, moving_var_5 = _get_conv_block_weights(5)

#Sublayer 2 (conv blocks 6-8)
conv_6_weight, conv_6_bias, gamma_6, beta_6, moving_mean_6, moving_var_6 = _get_conv_block_weights(6)
conv_7_weight, conv_7_bias, gamma_7, beta_7, moving_mean_7, moving_var_7 = _get_conv_block_weights(7)
conv_8_weight, conv_8_bias, gamma_8, beta_8, moving_mean_8, moving_var_8 = _get_conv_block_weights(8)

#Sublayer 3 (conv block 9)
conv_9_weight, conv_9_bias, gamma_9, beta_9, moving_mean_9, moving_var_9 = _get_conv_block_weights(9)

#Dense head: first dense layer with its batch norm, then the output layer
dense_10_weight, dense_10_bias = saved_predictor.get_layer('dragonn_dense_1_copy').get_weights()
gamma_10, beta_10, moving_mean_10, moving_var_10 = saved_predictor.get_layer('dragonn_batchnorm_10_copy').get_weights()

dense_11_weight, dense_11_bias = saved_predictor.get_layer('dragonn_dense_2_copy').get_weights()



In [6]:

#Shapes of the collected keras arrays for the first two conv blocks, one group per layer
keras_shape_groups = [
    (conv_1_weight, conv_1_bias),
    (beta_1, gamma_1, moving_mean_1, moving_var_1),
    (conv_2_weight, conv_2_bias),
    (beta_2, gamma_2, moving_mean_2, moving_var_2),
]

#One shape per line, groups separated by a dashed line
print("\n----------\n".join(
    "\n".join(str(keras_array.shape) for keras_array in shape_group)
    for shape_group in keras_shape_groups
))


(3, 1, 4, 48)
(48,)
----------
(48,)
(48,)
(48,)
(48,)
----------
(3, 1, 48, 64)
(64,)
----------
(64,)
(64,)
(64,)
(64,)


In [7]:

#Shapes of the matching pytorch tensors for the first two conv blocks, one group per layer
torch_shape_groups = [
    (analyzer.cnn.conv1.weight, analyzer.cnn.conv1.bias),
    (analyzer.cnn.norm1.bias, analyzer.cnn.norm1.weight, analyzer.cnn.norm1.running_mean, analyzer.cnn.norm1.running_var),
    (analyzer.cnn.conv2.weight, analyzer.cnn.conv2.bias),
    (analyzer.cnn.norm2.bias, analyzer.cnn.norm2.weight, analyzer.cnn.norm2.running_mean, analyzer.cnn.norm2.running_var),
]

#One shape per line, groups separated by a dashed line
print("\n----------\n".join(
    "\n".join(str(torch_tensor.shape) for torch_tensor in shape_group)
    for shape_group in torch_shape_groups
))


torch.Size([48, 4, 1, 3])
torch.Size([48])
----------
torch.Size([48])
torch.Size([48])
torch.Size([48])
torch.Size([48])
----------
torch.Size([64, 48, 1, 3])
torch.Size([64])
----------
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


In [8]:
#Manually transfer model weights from keras to pytorch
#
# Values are copied IN PLACE into the existing pytorch tensors instead of assigning fresh
# nn.Parameter objects. This keeps running_mean / running_var registered as batch-norm
# buffers (assigning an nn.Parameter would turn them into trainable parameters), keeps
# every tensor on the device the model already lives on, and lets us verify shapes.

def _copy_keras_array_(target, keras_array) :
    """Copy a numpy array into an existing torch tensor in place, after checking shapes match."""
    source = torch.from_numpy(np.ascontiguousarray(keras_array, dtype=np.float32))
    if tuple(source.shape) != tuple(target.shape) :
        raise ValueError(
            "Shape mismatch during weight transfer: keras " + str(tuple(source.shape)) +
            " vs pytorch " + str(tuple(target.shape))
        )
    target.copy_(source)

def _copy_conv_(conv, kernel, bias) :
    """Copy a keras conv kernel (axes reversed to match the pytorch layout) and its bias."""
    _copy_keras_array_(conv.weight, np.transpose(kernel, (3, 2, 1, 0)))
    _copy_keras_array_(conv.bias, bias)

def _copy_dense_(fc, kernel, bias) :
    """Copy a keras dense kernel (transposed to match the pytorch layout) and its bias."""
    _copy_keras_array_(fc.weight, np.transpose(kernel, (1, 0)))
    _copy_keras_array_(fc.bias, bias)

def _copy_batchnorm_(norm, gamma, beta, moving_mean, moving_var) :
    """Copy keras batch-norm scale/offset and moving statistics into a pytorch norm layer."""
    _copy_keras_array_(norm.weight, gamma)
    _copy_keras_array_(norm.bias, beta)
    _copy_keras_array_(norm.running_mean, moving_mean)
    _copy_keras_array_(norm.running_var, moving_var)

#no_grad is required for in-place writes into parameters that require gradients
with torch.no_grad() :
    _copy_conv_(analyzer.cnn.conv1, conv_1_weight, conv_1_bias)
    _copy_batchnorm_(analyzer.cnn.norm1, gamma_1, beta_1, moving_mean_1, moving_var_1)
    
    _copy_conv_(analyzer.cnn.conv2, conv_2_weight, conv_2_bias)
    _copy_batchnorm_(analyzer.cnn.norm2, gamma_2, beta_2, moving_mean_2, moving_var_2)
    
    _copy_conv_(analyzer.cnn.conv3, conv_3_weight, conv_3_bias)
    _copy_batchnorm_(analyzer.cnn.norm3, gamma_3, beta_3, moving_mean_3, moving_var_3)
    
    _copy_conv_(analyzer.cnn.conv4, conv_4_weight, conv_4_bias)
    _copy_batchnorm_(analyzer.cnn.norm4, gamma_4, beta_4, moving_mean_4, moving_var_4)
    
    _copy_conv_(analyzer.cnn.conv5, conv_5_weight, conv_5_bias)
    _copy_batchnorm_(analyzer.cnn.norm5, gamma_5, beta_5, moving_mean_5, moving_var_5)
    
    _copy_conv_(analyzer.cnn.conv6, conv_6_weight, conv_6_bias)
    _copy_batchnorm_(analyzer.cnn.norm6, gamma_6, beta_6, moving_mean_6, moving_var_6)
    
    _copy_conv_(analyzer.cnn.conv7, conv_7_weight, conv_7_bias)
    _copy_batchnorm_(analyzer.cnn.norm7, gamma_7, beta_7, moving_mean_7, moving_var_7)
    
    _copy_conv_(analyzer.cnn.conv8, conv_8_weight, conv_8_bias)
    _copy_batchnorm_(analyzer.cnn.norm8, gamma_8, beta_8, moving_mean_8, moving_var_8)
    
    _copy_conv_(analyzer.cnn.conv9, conv_9_weight, conv_9_bias)
    _copy_batchnorm_(analyzer.cnn.norm9, gamma_9, beta_9, moving_mean_9, moving_var_9)
    
    _copy_dense_(analyzer.cnn.fc10, dense_10_weight, dense_10_bias)
    _copy_batchnorm_(analyzer.cnn.norm10, gamma_10, beta_10, moving_mean_10, moving_var_10)
    
    _copy_dense_(analyzer.cnn.fc11, dense_11_weight, dense_11_bias)

#Persist the transferred weights under checkpoint number 10
analyzer.save_model(epoch=10)


In [9]:
#Reload pytorch model and compare predict function to keras model
# Re-creating the classifier with the same run_name is meant to pick up the checkpoint saved
# after the weight transfer (the cell output reports that checkpoint 10 was found).

analyzer = DragoNNClassifier(run_name='mpradragonn_pytorch', seq_len=145)


[*] Checkpoint 10 found!


In [10]:

n_seqs_to_test = 64

sequence_template = 'N' * 145

#Build random data
def _sample_from_template(template) :
    """Return a copy of template where every 'N' is replaced by a random nucleotide.

    Draws one np.random.choice per 'N', left to right; other characters are kept as-is.
    """
    sampled_nts = []
    for template_nt in template :
        if template_nt == 'N' :
            sampled_nts.append(np.random.choice(['A', 'C', 'G', 'T']))
        else :
            sampled_nts.append(template_nt)
    return ''.join(sampled_nts)

random_seqs = [_sample_from_template(sequence_template) for _ in range(n_seqs_to_test)]

#Stack the per-sequence one-hot matrices along a new leading (batch) axis
onehots_random = np.stack(
    [acgt_encoder.encode(rand_seq) for rand_seq in random_seqs],
    axis=0
)


In [11]:
#Predict fitness using keras model
# saved_predictor_w_dense has two outputs, so predict returns the 12-column prediction matrix
# and the intermediate activation (debug_keras, unused below).
prob_random_keras, debug_keras = saved_predictor_w_dense.predict(x=[onehots_random], batch_size=32)
# Keep only output column 5 of the 12 outputs.
# NOTE(review): presumably this is the task that analyzer.predict_model reports — confirm.
# The values come from a linear output layer, so they are not probabilities despite the name.
prob_random_keras = np.ravel(prob_random_keras[:, 5])

#Predict fitness using pytorch model
# predict_model is given the raw sequence strings (not the one-hot arrays).
prob_random_pytorch = analyzer.predict_model(random_seqs)
prob_random_pytorch = np.ravel(prob_random_pytorch)


In [14]:

#Print the keras and pytorch predictions side by side, rounded to 4 decimals
paired_predictions = zip(prob_random_keras.tolist(), prob_random_pytorch.tolist())

for i, (p_keras, p_pytorch) in enumerate(paired_predictions) :
    print("\n".join([
        "--------------------",
        "Sequence " + str(i),
        "prob (keras) = " + str(round(p_keras, 4)),
        "prob (pytorch) = " + str(round(p_pytorch, 4)),
    ]))


--------------------
Sequence 0
prob (keras) = -0.2048
prob (pytorch) = -0.2048
--------------------
Sequence 1
prob (keras) = 0.0621
prob (pytorch) = 0.0621
--------------------
Sequence 2
prob (keras) = -0.1181
prob (pytorch) = -0.1181
--------------------
Sequence 3
prob (keras) = -0.1441
prob (pytorch) = -0.1441
--------------------
Sequence 4
prob (keras) = 0.1855
prob (pytorch) = 0.1855
--------------------
Sequence 5
prob (keras) = -0.2105
prob (pytorch) = -0.2105
--------------------
Sequence 6
prob (keras) = 0.0355
prob (pytorch) = 0.0355
--------------------
Sequence 7
prob (keras) = -0.0108
prob (pytorch) = -0.0108
--------------------
Sequence 8
prob (keras) = -0.1303
prob (pytorch) = -0.1303
--------------------
Sequence 9
prob (keras) = -0.1204
prob (pytorch) = -0.1204
--------------------
Sequence 10
prob (keras) = -0.1226
prob (pytorch) = -0.1226
--------------------
Sequence 11
prob (keras) = -0.1951
prob (pytorch) = -0.1951
--------------------
Sequence 12
prob (keras