In [1]:
import numpy as np
import scipy

import pandas as pd

import random, os, h5py, math, time, glob

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

import sklearn
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.preprocessing import OneHotEncoder

import keras
import keras.backend as K
from keras.layers import GRU, CuDNNGRU, Dense, Lambda, Dropout, Input, Embedding, Flatten
from keras import Model
import keras.optimizers
from keras.models import load_model

import tensorflow as tf

#tf.disable_v2_behavior()

from apa_predictor_pytorch import *

class IdentityEncoder :
    """One-hot encoder/decoder for fixed-length character sequences.

    A sequence is encoded as a (seq_len, n_channels) float matrix whose row i has
    a 1. in column channel_map[seq[i]]. Characters missing from channel_map
    (e.g. 'N') leave their row all zeros, and rows past the end of a shorter
    sequence stay zero (padding).
    """
    
    def __init__(self, seq_len, channel_map) :
        """
        seq_len     : number of rows of every encoding.
        channel_map : dict mapping each character to its column index
                      (assumed to be 0..len(channel_map)-1).
        """
        self.seq_len = seq_len
        self.n_channels = len(channel_map)
        self.encode_map = channel_map
        #Inverse lookup: column index -> character (used by decode)
        self.decode_map = {
            ix: nt for nt, ix in self.encode_map.items()
        }
    
    def encode(self, seq) :
        """Return a new (seq_len, n_channels) one-hot matrix for seq.

        Raises IndexError if len(seq) > seq_len.
        """
        encoding = np.zeros((self.seq_len, self.n_channels))
        self.encode_inplace(seq, encoding)
        return encoding
    
    def encode_inplace(self, seq, encoding) :
        """Set the one-hot entries of seq into an existing matrix `encoding`.

        Only writes 1. entries; the caller is responsible for passing a zeroed
        matrix. Characters not in the channel map are skipped.
        """
        for i, nt in enumerate(seq) :
            if nt in self.encode_map :
                encoding[i, self.encode_map[nt]] = 1.
    
    def encode_inplace_sparse(self, seq, encoding_mat, row_index) :
        """Sparse encoding is not supported by this encoder."""
        raise NotImplementedError("IdentityEncoder does not support sparse encoding")
    
    def decode(self, encoding) :
        """Decode a (length, n_channels) matrix back into a string.

        Each row is mapped to the character of its arg-max column; note that an
        all-zero row (padding / unknown character) therefore decodes to the
        character of column 0.
        """
        return ''.join(
            self.decode_map[np.argmax(encoding[pos, :])]
            for pos in range(encoding.shape[0])
        )
    
    def decode_sparse(self, encoding_mat, row_index) :
        """Sparse decoding is not supported by this encoder."""
        raise NotImplementedError("IdentityEncoder does not support sparse decoding")


Using TensorFlow backend.


In [2]:
#Instantiate the pytorch APA classifier skeleton for 205-nt inputs under the
#'aparent_pytorch' run name (its weights are overwritten from keras further down)
analyzer = APAClassifier(seq_len=205, run_name='aparent_pytorch')


[*] Checkpoint 15 found!


In [3]:
#Load the pre-trained APARENT keras predictor and the one-hot encoder that feeds it

#File name and location of the saved predictor network (relative to the working directory)
saved_predictor_model_name = 'aparent_plasmid_iso_cut_distalpas_all_libs_no_sampleweights_sgd.h5'
save_dir = os.path.join(os.getcwd(), '../aparent/saved_models')
saved_predictor_model_path = os.path.join(save_dir, saved_predictor_model_name)

saved_predictor = load_model(saved_predictor_model_path)

#205-position one-hot encoder; characters other than A/C/G/T (e.g. 'N') give all-zero rows
acgt_encoder = IdentityEncoder(205, {'A':0, 'C':1, 'G':2, 'T':3})




In [4]:
#Print the keras layer table; the layer names it lists ('conv2d_1', 'conv2d_2', ...)
#are the names looked up when collecting weights for the transfer to pytorch
saved_predictor.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 205, 4, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 198, 1, 96)   3168        input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 99, 1, 96)    0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 94, 1, 128)   73856       max_pooling2d_1[0][0]            
____________________________________________________________________________________________

In [5]:
#Pull the (kernel, bias) pair out of every keras layer that has a pytorch counterpart.
#These are keras auto-generated layer names (see the model summary); 'dense_3' is
#presumably the isoform output head, given the dense_iso_* naming — confirm in the summary.

[conv_1_weight, conv_1_bias] = saved_predictor.get_layer(name='conv2d_1').get_weights()
[conv_2_weight, conv_2_bias] = saved_predictor.get_layer(name='conv2d_2').get_weights()

[dense_1_weight, dense_1_bias] = saved_predictor.get_layer(name='dense_1').get_weights()
[dense_iso_weight, dense_iso_bias] = saved_predictor.get_layer(name='dense_3').get_weights()


In [6]:
#Manually transfer model weights from keras to pytorch.
#Keras stores Conv2D kernels as (kernel_h, kernel_w, in_channels, out_channels) and
#Dense kernels as (in_features, out_features); each `axes` permutation below reorders
#a keras kernel into the layout of the matching pytorch parameter.

def _copy_keras_weight(param, keras_array, axes=None) :
    """Overwrite an existing pytorch parameter in place with a keras weight array.

    param       : the registered torch parameter to fill (its identity, device and
                  dtype are kept, unlike replacing it with a new nn.Parameter).
    keras_array : numpy array returned by keras `Layer.get_weights()`.
    axes        : optional permutation for np.transpose; None copies as-is.

    Raises ValueError when the (permuted) keras array does not have exactly the
    parameter's shape, so a wrong permutation fails before the checkpoint is saved.
    """
    value = np.asarray(keras_array) if axes is None else np.transpose(keras_array, axes)
    if tuple(value.shape) != tuple(param.shape) :
        raise ValueError(
            "Shape mismatch when copying keras weights: keras " + str(tuple(value.shape))
            + " vs pytorch " + str(tuple(param.shape))
        )
    with torch.no_grad() :
        param.copy_(
            torch.from_numpy(np.ascontiguousarray(value)).to(device=param.device, dtype=param.dtype)
        )

_copy_keras_weight(analyzer.cnn.conv1.weight, conv_1_weight, (3, 1, 2, 0))
_copy_keras_weight(analyzer.cnn.conv1.bias, conv_1_bias)

_copy_keras_weight(analyzer.cnn.conv2.weight, conv_2_weight, (3, 2, 1, 0))
_copy_keras_weight(analyzer.cnn.conv2.bias, conv_2_bias)

_copy_keras_weight(analyzer.cnn.fc1.weight, dense_1_weight, (1, 0))
_copy_keras_weight(analyzer.cnn.fc1.bias, dense_1_bias)

_copy_keras_weight(analyzer.cnn.fc2.weight, dense_iso_weight, (1, 0))
_copy_keras_weight(analyzer.cnn.fc2.bias, dense_iso_bias)

analyzer.save_model(epoch=15)


In [4]:
#Re-create the pytorch classifier from the checkpoint saved above, so the comparison
#below exercises the reloaded weights rather than the in-memory copy
analyzer = APAClassifier(seq_len=205, run_name='aparent_pytorch')


[*] Checkpoint 15 found!


In [5]:

#Build random test sequences from the plasmid template: every 'N' is replaced by a
#uniformly drawn nucleotide, all other positions are kept as-is.
n_seqs_to_test = 64
random_seed = 42

sequence_template = 'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG'

#Dedicated seeded generator: Restart & Run All reproduces the same sequences (and
#hence the same printed predictions) without touching numpy's global RNG state.
rng = np.random.RandomState(random_seed)

random_seqs = [
    ''.join([
        nt if nt != 'N' else rng.choice(['A', 'C', 'G', 'T'])
        for nt in sequence_template
    ]) for _ in range(n_seqs_to_test)
]

#One-hot encode into the keras input layout (n_seqs, 205, 4, 1)
onehots_random = np.stack([
    acgt_encoder.encode(rand_seq) for rand_seq in random_seqs
], axis=0)[..., np.newaxis]

#Auxiliary keras inputs, identical for every test sequence:
#  fake_lib - one-hot over 13 columns with column 5 set (presumably the library
#             indicator; TODO confirm which library index 5 corresponds to)
#  fake_d   - all-ones column (presumably the distal-PAS flag, cf. 'distalpas' in
#             the model file name; TODO confirm)
fake_lib = np.zeros((n_seqs_to_test, 13))
fake_lib[:, 5] = 1.
fake_d = np.ones((n_seqs_to_test, 1))


In [6]:
#Score the same random sequences with both implementations.

#keras: inputs are [one-hot sequences, fake_lib, fake_d]; the model returns two
#outputs and only the first one (kept as iso_*) is compared
iso_random_keras, _ = saved_predictor.predict([onehots_random, fake_lib, fake_d], batch_size=32)
prob_random_keras = iso_random_keras.ravel()

#pytorch: takes the raw sequence strings directly
iso_random_pytorch = analyzer.predict_model(random_seqs)
prob_random_pytorch = np.ravel(iso_random_pytorch)




In [7]:

#Side-by-side listing of both models' predictions for each test sequence
for i, (p_keras, p_pytorch) in enumerate(zip(prob_random_keras.tolist(), prob_random_pytorch.tolist())) :
    print("--------------------")
    print("Sequence " + str(i))
    print("prob (keras) = " + str(round(p_keras, 4)))
    print("prob (pytorch) = " + str(round(p_pytorch, 4)))

#Explicit numerical check. The listing above is rounded to 4 decimals and zip() stops
#at the shorter input, so neither small discrepancies nor a length mismatch would be
#visible there.
assert prob_random_keras.shape == prob_random_pytorch.shape, (
    "Prediction count mismatch: keras " + str(prob_random_keras.shape)
    + " vs pytorch " + str(prob_random_pytorch.shape)
)
max_abs_diff = float(np.max(np.abs(prob_random_keras - prob_random_pytorch)))
print("====================")
print("max |keras - pytorch| = " + str(max_abs_diff))
np.testing.assert_allclose(prob_random_pytorch, prob_random_keras, rtol=0., atol=1e-5)


--------------------
Sequence 0
prob (keras) = 0.2001
prob (pytorch) = 0.2001
--------------------
Sequence 1
prob (keras) = 0.2159
prob (pytorch) = 0.2159
--------------------
Sequence 2
prob (keras) = 0.4266
prob (pytorch) = 0.4266
--------------------
Sequence 3
prob (keras) = 0.1174
prob (pytorch) = 0.1174
--------------------
Sequence 4
prob (keras) = 0.387
prob (pytorch) = 0.387
--------------------
Sequence 5
prob (keras) = 0.1961
prob (pytorch) = 0.1961
--------------------
Sequence 6
prob (keras) = 0.0945
prob (pytorch) = 0.0945
--------------------
Sequence 7
prob (keras) = 0.3683
prob (pytorch) = 0.3683
--------------------
Sequence 8
prob (keras) = 0.1432
prob (pytorch) = 0.1432
--------------------
Sequence 9
prob (keras) = 0.3353
prob (pytorch) = 0.3353
--------------------
Sequence 10
prob (keras) = 0.2542
prob (pytorch) = 0.2542
--------------------
Sequence 11
prob (keras) = 0.1026
prob (pytorch) = 0.1026
--------------------
Sequence 12
prob (keras) = 0.1937
prob (pyt