In [1]:
import sys
import os
import itertools
from keras.layers import Input, Dense, Reshape, Flatten
from keras import layers, initializers
from keras.models import Model, load_model
import keras.backend as K
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats
from scipy.stats import norm
from scipy.optimize import minimize
from keras.utils.generic_utils import get_custom_objects
import json
#import tensorflow_probability as tfp

#tfd = tfp.distributions

from keras.backend.tensorflow_backend import set_session

def contain_tf_gpu_mem_usage() :
    """Register a Keras backend session with on-demand GPU memory growth.

    Creates a TensorFlow session config with ``gpu_options.allow_growth``
    switched on, opens a session with that config and hands it to Keras
    through ``set_session``.
    """
    growth_config = tf.ConfigProto()
    growth_config.gpu_options.allow_growth = True
    set_session(tf.Session(config=growth_config))

# Install the session once, before any model is created in this notebook.
contain_tf_gpu_mem_usage()


Using TensorFlow backend.


In [2]:
from tensorflow.python.framework import ops

#Stochastic Binarized Neuron helper functions (Tensorflow)
#ST Estimator code adapted from https://r2rt.com/beyond-binary-ternary-and-one-hot-neurons.html
#See Github https://github.com/spitis/

def st_sampled_softmax(logits):
    """Sample a one-hot row per row of `logits`, with overridden gradients.

    Forward: one category per row is drawn with ``tf.multinomial(logits, 1)``
    and one-hot encoded; the result is ``ceil(one_hot * softmax(logits))``.
    Backward: inside the override scope the 'Ceil' gradient is replaced by
    'Identity' and the 'Mul' gradient by the registered 'STMul' gradient.

    The one-hot depth is read from the static size of axis 1 of `logits`.
    """
    with ops.name_scope("STSampledSoftmax") :
        probs = tf.nn.softmax(logits)
        n_classes = logits.get_shape().as_list()[1]
        sampled_ix = tf.squeeze(tf.multinomial(logits, 1), 1)
        onehot_sample = tf.one_hot(sampled_ix, n_classes, 1.0, 0.0)
        # The multiply and ceil ops must be created inside this scope so that
        # the gradient overrides apply to them.
        with tf.get_default_graph().gradient_override_map({'Ceil': 'Identity', 'Mul': 'STMul'}):
            return tf.ceil(onehot_sample * probs)

def st_hardmax_softmax(logits):
    """One-hot encode the argmax of each row of `logits`, with overridden gradients.

    Forward: the index of the largest softmax probability per row is one-hot
    encoded; the result is ``ceil(one_hot * softmax(logits))``.
    Backward: inside the override scope the 'Ceil' gradient is replaced by
    'Identity' and the 'Mul' gradient by the registered 'STMul' gradient.

    The one-hot depth is read from the static size of axis 1 of `logits`.
    """
    with ops.name_scope("STHardmaxSoftmax") :
        probs = tf.nn.softmax(logits)
        n_classes = logits.get_shape().as_list()[1]
        top_ix = tf.argmax(probs, 1)
        onehot_top = tf.one_hot(top_ix, n_classes, 1.0, 0.0)
        # The multiply and ceil ops must be created inside this scope so that
        # the gradient overrides apply to them.
        with tf.get_default_graph().gradient_override_map({'Ceil': 'Identity', 'Mul': 'STMul'}):
            return tf.ceil(onehot_top * probs)

@ops.RegisterGradient("STMul")
def st_mul(op, grad):
    """Gradient registered as 'STMul': return `grad` unchanged for both inputs."""
    return [grad] * 2


In [3]:
import warnings
# Suppress every Python warning from this point on.
warnings.filterwarnings("ignore")

In [4]:
#Paths of the saved models (nothing is loaded here; run_killoran loads them)

#Path prefix of the VAE; run_killoran appends '_decoder.h5' to load the decoder
vae_path = "../vae/saved_models/vae_apa_max_isoform_simple_strong_len_96_latent_100_epochs_50_kl_factor_1125_annealed"

#Saved predictor model; _load_aparent_func copies its layer weights into the oracle
predictor_path = '../../../../aparent/saved_models/aparent_plasmid_iso_cut_distalpas_all_libs_no_sampleweights_sgd.h5'


In [5]:
import isolearn.io as isoio
import isolearn.keras as isol

#One-hot encoder for sequences of length 205; default encoder of initialize_sequence_templates
encoder = isol.OneHotEncoder(seq_length=205)

def initialize_sequence_templates(sequence_template, encoder=encoder) :
    """Build the one-hot template and the free-position mask of a sequence.

    Parameters
    ----------
    sequence_template : str
        Template sequence; the letter 'N' marks a free position.
    encoder : callable
        Maps the template to an array that can be reshaped to
        ``(1, len(sequence_template), 4, 1)``.

    Returns
    -------
    (onehot_template, onehot_mask)
        Two arrays of shape ``(1, len(sequence_template), 4, 1)``.
        At an 'N' position the template is all zeros and the mask is all
        ones. At every other position the template holds a single 1 in the
        channel where the encoding is largest and the mask is all zeros.
    """
    seq_len = len(sequence_template)

    onehot_template = encoder(sequence_template).reshape((1, seq_len, 4, 1))
    onehot_mask = np.zeros((1, seq_len, 4, 1))

    for pos, letter in enumerate(sequence_template) :
        if letter == 'N' :
            # Free position: nothing fixed in the template, fully open mask.
            onehot_template[0, pos, :, :] = 0
            onehot_mask[0, pos, :, :] = 1.0
            continue

        # Fixed position: keep only the strongest channel of the encoding.
        hot_ix = np.argmax(onehot_template[0, pos, :, 0])
        onehot_template[0, pos, :, :] = 0
        onehot_template[0, pos, hot_ix, :] = 1

    return onehot_template, onehot_mask

sequence_template = 'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNANTAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG'

#One-hot template and free-position mask of sequence_template; default arguments of killoran_opt
template_mat, mask_mat = initialize_sequence_templates(sequence_template)


In [13]:

def killoran_opt(vae_decoder, oracle,
                 steps=20000, store_every1=5, store_every2=100, store_swap_iter=100, epsilon1=10**-5, epsilon2=1, noise_std=10**-5, save_path=None,
                 LD=100, verbose=False, adam=False, template_mat=template_mat, mask_mat=mask_mat, encoder=encoder):
    """Optimize one latent vector of `vae_decoder` against the `oracle` score.

    A latent variable ``zt`` of shape ``(1, LD)`` is decoded, zero-padded to
    205 positions (25 on the left, 84 on the right), multiplied by the mask
    and added to the template. The optimized score is element ``[0, 0]`` of
    the first oracle output; its negative is differentiated with respect to
    the oracle input and chained through the decoder to ``zt``, so every
    optimizer step moves ``zt`` towards a higher score.

    Parameters
    ----------
    vae_decoder : callable
        Called as ``vae_decoder([zt_dummy, zt])``; output index 1 is used and
        transposed with ``(0, 2, 3, 1)``.
    oracle : callable
        Called as ``oracle([pred_input, lib_input, lib_distal])``.
    steps : int
        Number of optimization steps.
    store_every1, store_every2, store_swap_iter : int
        The sequence is written every `store_every1` steps; once the step
        index exceeds `store_swap_iter`, every `store_every2` steps.
    epsilon1 : float
        Weight of the prior term (non-Adam update only).
    epsilon2 : float
        Adam learning rate if `adam`, otherwise weight of the oracle gradient.
    noise_std : float
        Standard deviation of the Gaussian noise added to every step.
    save_path : str or None
        If not None, the argmax sequence of step ``t`` is appended to
        ``save_path + "_iter_" + str(t) + ".txt"``.
    LD : int
        Latent dimension.
    verbose, encoder :
        Not used by this function; kept for interface compatibility.
    adam : bool
        Use ``AdamOptimizer`` on ``dfdz + noise`` instead of the plain
        ``eps1 * dpz + eps2 * dfdz + noise`` update.
    template_mat, mask_mat : array
        Template and mask applied to the padded decoder output.

    Returns
    -------
    None
    """
    sess = K.get_session()

    zt = K.tf.Variable(np.random.normal(size=[1, LD]), dtype='float32')
    zt_dummy = K.tf.Variable(np.zeros((1, 1)), trainable=False, dtype='float32')

    # Library input: one-hot over 13 entries with index 5 active.
    lib_simple = np.zeros((1, 13))
    lib_simple[0, 5] = 1.

    pred_input = K.tf.Variable(np.zeros((1, 205, 4, 1)), dtype='float32')
    lib_input = K.tf.Variable(lib_simple, trainable=False, dtype='float32')
    lib_distal = K.tf.Variable(np.ones((1, 1)), trainable=False, dtype='float32')

    template = K.tf.Variable(template_mat, trainable=False, dtype='float32')
    mask = K.tf.Variable(mask_mat, trainable=False, dtype='float32')

    left_pad = K.tf.Variable(np.zeros((1, 25, 4, 1)), trainable=False, dtype='float32')
    right_pad = K.tf.Variable(np.zeros((1, 84, 4, 1)), trainable=False, dtype='float32')

    gen_output = K.tf.concat([left_pad, K.tf.transpose(vae_decoder([zt_dummy, zt])[1], (0, 2, 3, 1)), right_pad], axis=1) * mask + template

    predictions = oracle([pred_input, lib_input, lib_distal])[0][0, 0]
    update_pred_input = K.tf.assign(pred_input, gen_output)
    # Gradient of the negated score w.r.t. the oracle input, then chained
    # through the decoder to the latent variable.
    dfdx = K.tf.gradients(ys=-predictions, xs=pred_input)[0]
    dfdz = K.tf.gradients(gen_output, zt, grad_ys=dfdx)[0]

    noise = K.tf.random_normal(shape=[1, LD], stddev=noise_std)
    # Force float32 so that integer arguments (e.g. the default epsilon2=1)
    # do not create int32 variables that cannot be multiplied with gradients.
    eps1 = K.tf.Variable(float(epsilon1), trainable=False, dtype='float32')
    eps2 = K.tf.Variable(float(epsilon2), trainable=False, dtype='float32')
    if adam:
        optimizer = K.tf.train.AdamOptimizer(learning_rate=epsilon2)
        step = dfdz + noise
    else:
        optimizer = K.tf.train.GradientDescentOptimizer(learning_rate=1)
        # Gradient of the standard-normal log-density log N(z; 0, 1) w.r.t. z.
        # This is what the disabled tensorflow_probability code
        # (prior.log_prob followed by tf.gradients) evaluated to; without it
        # this branch referenced an undefined name.
        dpz = -zt
        step = eps1 * dpz + eps2 * dfdz + noise

    design_op = optimizer.apply_gradients([(step, zt)])

    # Initialize only the state created by this optimizer. Matching variable
    # names on 'Adam'/'beta' could also reset unrelated variables (for example
    # normalization 'beta' weights of a loaded model).
    sess.run([opt_var.initializer for opt_var in optimizer.variables()])
    local_vars = (zt, zt_dummy, pred_input, lib_input, lib_distal,
                  template, mask, left_pad, right_pad, eps1, eps2)
    sess.run([local_var.initializer for local_var in local_vars])

    # Seed pred_input with the decoding of a freshly drawn latent vector.
    # NOTE(review): the draw is fed in place of zt, so the first step uses an
    # oracle gradient evaluated at a different point than zt - confirm intent.
    sess.run(update_pred_input, {zt: np.random.normal(size=(1, LD))})

    nt_map_inv = {0:'A', 1:'C', 2:'G', 3:'T'}
    store_every = store_every1

    for t in range(steps):
        if t % 1000 == 0 :
            print("Running step " + str(t) + "...")

        if t > store_swap_iter :
            store_every = store_every2

        xt0, _ = sess.run([gen_output, design_op], {eps1: epsilon1, eps2: epsilon2})
        # Refresh the oracle input with the decoding of the updated latent.
        sess.run([update_pred_input, predictions])

        if save_path is not None and t % store_every == 0 :
            # Decode the argmax nucleotide per position only when it is stored.
            xt_seq = ''.join(nt_map_inv[nt_ix] for nt_ix in np.argmax(xt0[0, :, :, 0], axis=-1))
            with open(save_path + "_iter_" + str(t) + ".txt", "a+") as out_file :
                out_file.write(xt_seq + "\n")


In [14]:
import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute
from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import regularizers
from keras import backend as K
import keras.losses

def _load_aparent_func(model_path) :
    """Rebuild the predictor as a functional Keras model and copy saved weights.

    The model takes three inputs (``seq_input`` of shape (205, 4, 1),
    ``lib_input`` of shape (13,), ``distal_pas_input`` of shape (1,)) and
    returns four outputs in this order: isoform score, cut scores (206),
    sigmoid of the isoform score, softmax of the cut scores.

    The weights of the two convolutions and three dense layers are copied
    from the model stored at `model_path`; those layers and the whole model
    are marked non-trainable before compiling.
    """
    seq_input = Input(shape=(205, 4, 1), name='seq_input')
    lib_input = Input(shape=(13,), name='lib_input')
    distal_pas_input = Input(shape=(1,), name='distal_pas_input')

    #Layers of the shared trunk
    conv_1 = Conv2D(96, (8, 4), padding='valid', activation='relu', name='aparent_conv_1')
    pool_1 = MaxPooling2D(pool_size=(2, 1))
    conv_2 = Conv2D(128, (6, 1), padding='valid', activation='relu', name='aparent_conv_2')
    dense_1 = Dense(256, activation='relu', name='aparent_dense_1')
    dropout = Dropout(0.2)

    #Unnamed merge layers; instantiation order determines their auto-generated names
    join_lib = Concatenate()
    join_distal = Concatenate()
    flatten = Flatten()

    #Shared trunk: conv -> pool -> conv -> flatten -> (+ distal flag) -> dense -> dropout
    conv_features = flatten(conv_2(pool_1(conv_1(seq_input))))
    hidden = dense_1(join_distal([conv_features, distal_pas_input]))
    # Dropout is called with training=False.
    shared_out = dropout(hidden, training=False)

    #Outputs
    plasmid_out_shared = join_lib([shared_out, lib_input])

    plasmid_score_cut = Dense(206, kernel_initializer='zeros', name='aparent_cut_dense')(plasmid_out_shared)
    plasmid_score_iso = Dense(1, kernel_initializer='zeros', name='aparent_iso_dense')(plasmid_out_shared)

    plasmid_out_cut = Softmax(axis=-1)(plasmid_score_cut)
    # Bias-free Dense(1) with a ones-initialized kernel and sigmoid activation.
    plasmid_out_iso = Dense(1, activation='sigmoid', kernel_initializer='ones', use_bias=False)(plasmid_score_iso)

    _oracle = Model(
        [seq_input, lib_input, distal_pas_input],
        [plasmid_score_iso, plasmid_score_cut, plasmid_out_iso, plasmid_out_cut]
    )

    #Copy weights layer by layer: (layer name in this model, layer name in the saved model)
    _saved_model = load_model(model_path)
    layer_name_pairs = [
        ('aparent_conv_1', 'conv2d_1'),
        ('aparent_conv_2', 'conv2d_2'),
        ('aparent_dense_1', 'dense_1'),
        ('aparent_cut_dense', 'dense_2'),
        ('aparent_iso_dense', 'dense_3'),
    ]
    for oracle_name, saved_name in layer_name_pairs :
        target_layer = _oracle.get_layer(oracle_name)
        target_layer.set_weights(_saved_model.get_layer(saved_name).get_weights())
        target_layer.trainable = False

    _oracle.trainable=False
    _oracle.compile(
        optimizer=keras.optimizers.SGD(lr=0.1), loss='mean_squared_error'
    )

    return _oracle
    

In [15]:

def run_killoran(n_traj=5, steps=20000, vae_prefix_str="", vae_path=vae_path, predictor_path=predictor_path):
    """Run `n_traj` independent latent-space optimization trajectories.

    For every trajectory a new TensorFlow session on the default graph is
    registered with Keras, the oracle is rebuilt from `predictor_path`, the
    VAE decoder is loaded from ``vae_path + '_decoder.h5'`` and
    ``killoran_opt`` is run with Adam for `steps` steps.

    All trajectories share the same save path, and ``killoran_opt`` opens its
    files in append mode, so each per-iteration file collects one line per
    trajectory. The printed trajectory number is informational only; it is
    not used to seed anything.

    NOTE(review): sessions are never closed and every trajectory adds its
    models to the same default graph.
    """
    for trajectory_id in range(1, n_traj + 1):
        print(trajectory_id)

        session = tf.Session(graph=tf.get_default_graph())
        K.set_session(session)

        #Load models
        oracle = _load_aparent_func(predictor_path)

        decoder_custom_objects = {
            'st_sampled_softmax': st_sampled_softmax,
            'st_hardmax_softmax': st_hardmax_softmax,
            'min_pred': lambda y_true, y_pred: y_pred,
        }
        decoder = load_model(vae_path + '_decoder.h5', custom_objects=decoder_custom_objects)

        killoran_opt(
            decoder, oracle,
            steps=steps,
            epsilon1=0., epsilon2=0.1, noise_std=1e-6,
            store_every1=5, store_every2=100, store_swap_iter=100,
            LD=100, verbose=False, adam=True,
            save_path='killoran/killoran_vae' + vae_prefix_str + '_apa_seqs'
        )


In [16]:

#Run 10 optimization trajectories of 2000 steps each
run_killoran(n_traj=10, steps=2000, vae_prefix_str="_epochs_50_kl_factor_1125")


1
Running step 0...
Running step 1000...
2
Running step 0...
Running step 1000...
3
Running step 0...
Running step 1000...
4
Running step 0...
Running step 1000...
5
Running step 0...
Running step 1000...
6
Running step 0...
Running step 1000...
7
Running step 0...
Running step 1000...
8
Running step 0...
Running step 1000...
9
Running step 0...
Running step 1000...
10
Running step 0...
Running step 1000...
