In [1]:
import sys
import os
import itertools
from keras.layers import Input, Dense, Reshape, Flatten
from keras import layers, initializers
from keras.models import Model, load_model
import keras.backend as K
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats
from scipy.stats import norm
from scipy.optimize import minimize
from keras.utils.generic_utils import get_custom_objects
import json
#import tensorflow_probability as tfp

#tfd = tfp.distributions

from keras.backend.tensorflow_backend import set_session

def contain_tf_gpu_mem_usage() :
    """Create a TF session with `gpu_options.allow_growth` enabled and hand it to Keras."""
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    #Register the session as the one the Keras backend will use
    set_session(tf.Session(config=gpu_config))

contain_tf_gpu_mem_usage()


Using TensorFlow backend.


In [2]:
from tensorflow.python.framework import ops

#Stochastic Binarized Neuron helper functions (Tensorflow)
#ST Estimator code adapted from https://r2rt.com/beyond-binary-ternary-and-one-hot-neurons.html
#See Github https://github.com/spitis/

def st_sampled_softmax(logits):
    """Draw one category per row of `logits` and return ceil(onehot * softmax(logits)).

    The product and the ceil are created under a gradient override map
    ('Ceil' -> 'Identity', 'Mul' -> 'STMul'), so the backward pass uses those
    registered gradients instead of the default ones.

    logits: 2-D tensor; its second dimension is read statically as the number
    of categories.
    """
    with ops.name_scope("STSampledSoftmax") :
        probs = tf.nn.softmax(logits)
        n_categories = logits.get_shape().as_list()[1]
        #One sampled category index per row, expanded to a 1.0/0.0 one-hot
        drawn_ix = tf.squeeze(tf.multinomial(logits, 1), 1)
        onehot_draw = tf.one_hot(drawn_ix, n_categories, 1.0, 0.0)
        grad_overrides = {'Ceil': 'Identity', 'Mul': 'STMul'}
        with tf.get_default_graph().gradient_override_map(grad_overrides):
            return tf.ceil(onehot_draw * probs)

def st_hardmax_softmax(logits):
    """Pick the argmax category per row of `logits` and return ceil(onehot * softmax(logits)).

    Deterministic counterpart of st_sampled_softmax: the one-hot comes from
    the argmax of the softmax probabilities rather than from a random draw.
    The same gradient override map ('Ceil' -> 'Identity', 'Mul' -> 'STMul')
    is applied to the product and the ceil.
    """
    with ops.name_scope("STHardmaxSoftmax") :
        probs = tf.nn.softmax(logits)
        n_categories = logits.get_shape().as_list()[1]
        top_ix = tf.argmax(probs, 1)
        onehot_top = tf.one_hot(top_ix, n_categories, 1.0, 0.0)
        grad_overrides = {'Ceil': 'Identity', 'Mul': 'STMul'}
        with tf.get_default_graph().gradient_override_map(grad_overrides):
            return tf.ceil(onehot_top * probs)

@ops.RegisterGradient("STMul")
def st_mul(op, grad):
    """Gradient registered under the name "STMul": both inputs receive the incoming gradient unchanged."""
    passthrough = grad
    return [passthrough, passthrough]


In [3]:
import warnings
#Suppress all Python warnings from this point on (applies to every later cell)
warnings.filterwarnings("ignore")

In [4]:
import isolearn.io as isoio
import isolearn.keras as isol

encoder = isol.OneHotEncoder(seq_length=205)

def initialize_sequence_templates(sequence_template, encoder=encoder) :
    """Build the one-hot template and the free-position mask for a template string.

    sequence_template: string in which 'N' marks a free position.
    encoder: callable applied to the string; its result is reshaped to
        (1, len(sequence_template), 4, 1).

    Returns (onehot_template, onehot_mask), both of shape
    (1, len(sequence_template), 4, 1). At 'N' positions the template row is
    all zeros and the mask row is all ones; at every other position the
    template row is a strict one-hot at the argmax of the encoded row and the
    mask row is all zeros.
    """
    n_positions = len(sequence_template)

    onehot_template = encoder(sequence_template).reshape((1, n_positions, 4, 1))
    onehot_mask = np.zeros((1, n_positions, 4, 1))

    for pos, nt in enumerate(sequence_template) :
        if nt == 'N' :
            #Free position: nothing fixed in the template, fully open in the mask
            onehot_template[0, pos, :, :] = 0
            onehot_mask[0, pos, :, :] = 1.0
        else :
            #Fixed position: re-write the row as a clean one-hot
            fixed_ix = np.argmax(onehot_template[0, pos, :, 0])
            onehot_template[0, pos, :, :] = 0
            onehot_template[0, pos, fixed_ix, :] = 1

    return onehot_template, onehot_mask

sequence_template = 'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG'

#One-hot template and mask for the full-length sequence; used as default arguments of _templated_predict
template_mat, mask_mat = initialize_sequence_templates(sequence_template)


In [5]:
import keras
import tensorflow as tf
from keras.models import Sequential, Model, load_model

from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute
from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from keras import regularizers
from keras import backend as K
from keras.utils.generic_utils import Progbar
from keras.layers.merge import _Merge
import keras.losses

def make_gen_resblock(n_channels=64, window_size=3, stride=1, dilation=1, group_ix=0, layer_ix=0) :
    """Create the layers of one generator residual block and return a function applying them.

    Main path: batch norm -> relu -> transposed conv (strided) -> batch norm
    -> relu -> conv (dilated). Shortcut: a (1, 1) transposed conv with the
    same stride applied to the block input. The two paths are summed.

    group_ix / layer_ix only enter the layer names
    ('policy_generator_resblock_<group_ix>_<layer_ix>_...').
    """
    name_prefix = 'policy_generator_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_'

    #First half of the main path (carries the stride)
    batch_norm_0 = BatchNormalization(name=name_prefix + 'batch_norm_0')
    relu_0 = Lambda(lambda x: K.relu(x))
    deconv_0 = Conv2DTranspose(n_channels, (1, window_size), strides=(1, stride), padding='same', activation='linear', kernel_initializer='glorot_uniform', name=name_prefix + 'deconv_0')

    #Second half of the main path (carries the dilation)
    batch_norm_1 = BatchNormalization(name=name_prefix + 'batch_norm_1')
    relu_1 = Lambda(lambda x: K.relu(x))
    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=(1, dilation), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_uniform', name=name_prefix + 'conv_1')

    #Shortcut path and the final sum
    skip_deconv_0 = Conv2DTranspose(n_channels, (1, 1), strides=(1, stride), padding='same', activation='linear', kernel_initializer='glorot_uniform', name=name_prefix + 'skip_deconv_0')
    skip_1 = Lambda(lambda x: x[0] + x[1], name=name_prefix + 'skip_1')

    def _resblock_func(input_tensor) :
        main_path = deconv_0(relu_0(batch_norm_0(input_tensor)))
        main_path = conv_1(relu_1(batch_norm_1(main_path)))

        shortcut = skip_deconv_0(input_tensor)

        return skip_1([main_path, shortcut])

    return _resblock_func

#Decoder Model definition
def load_decoder_resnet(seq_length=96, latent_size=100) :
    """Create the decoder layers and return a function mapping a latent tensor to 4-channel logits.

    The network is a dense layer reshaped to (1, 6, 256), five generator
    residual blocks with strides [2, 2, 2, 2, 1], and a final (1, 1)
    convolution with 4 output channels.

    Neither `seq_length` nor `latent_size` is read in the body; the layer
    sizes are fixed by the constants below.
    """
    #Generator network parameters (one entry per residual block)
    window_size = 3
    initial_length = 6
    strides = [2, 2, 2, 2, 1]
    dilations = [1, 1, 1, 1, 1]
    channels = [256, 128, 96, 64, 32]

    #Latent vector -> (1, initial_length, channels[0]) feature map
    policy_dense_0 = Dense(initial_length * channels[0], activation='linear', kernel_initializer='glorot_uniform', name='policy_generator_dense_0')
    policy_dense_0_reshape = Reshape((1, initial_length, channels[0]))

    resblocks = [
        make_gen_resblock(n_channels=block_channels, window_size=window_size, stride=block_stride, dilation=block_dilation, group_ix=0, layer_ix=block_ix)
        for block_ix, (block_channels, block_stride, block_dilation) in enumerate(zip(channels, strides, dilations))
    ]

    final_conv = Conv2D(4, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_uniform', name='policy_generator_final_conv')

    def _generator_func(seed_input) :
        output_tensor = policy_dense_0_reshape(policy_dense_0(seed_input))

        #Res block group 0
        for resblock in resblocks :
            output_tensor = resblock(output_tensor)

        #Final conv out
        return final_conv(output_tensor)

    return _generator_func


def make_disc_resblock(n_channels=64, window_size=8, dilation_rate=1, group_ix=0, layer_ix=0) :
    """Create the layers of one encoder residual block and return a function applying them.

    Main path: (batch norm -> relu -> conv) twice, both convolutions with the
    same window size and dilation rate. The block input is added to the
    main-path output without any projection.

    group_ix / layer_ix only enter the layer names
    ('policy_discriminator_resblock_<group_ix>_<layer_ix>_...').
    """
    name_prefix = 'policy_discriminator_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_'

    #First batch norm / relu / conv stage
    batch_norm_0 = BatchNormalization(name=name_prefix + 'batch_norm_0')
    relu_0 = Lambda(lambda x: K.relu(x, alpha=0.0))
    conv_0 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name=name_prefix + 'conv_0')

    #Second batch norm / relu / conv stage
    batch_norm_1 = BatchNormalization(name=name_prefix + 'batch_norm_1')
    relu_1 = Lambda(lambda x: K.relu(x, alpha=0.0))
    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name=name_prefix + 'conv_1')

    #Identity shortcut sum
    skip_1 = Lambda(lambda x: x[0] + x[1], name=name_prefix + 'skip_1')

    def _resblock_func(input_tensor) :
        main_path = conv_0(relu_0(batch_norm_0(input_tensor)))
        main_path = conv_1(relu_1(batch_norm_1(main_path)))

        return skip_1([main_path, input_tensor])

    return _resblock_func

#Encoder Model definition
def load_encoder_network_4_resblocks(batch_size, seq_length=205, latent_size=100, drop_rate=0.25) :
    """Create the encoder layers (one group of 4 residual blocks) and return a function applying them.

    The returned function maps a sequence tensor to (z_mean, z_log_var), each
    produced by a Dense layer with `latent_size` units on the flattened
    features.

    `batch_size`, `seq_length` and `drop_rate` are not read in the body.
    """
    n_resblocks = 4
    n_channels = 32

    def _pointwise_conv(layer_name) :
        #All (1, 1) convolutions of this encoder share the same configuration
        return Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name=layer_name)

    policy_conv_0 = _pointwise_conv('policy_discriminator_conv_0')

    skip_conv_0 = _pointwise_conv('policy_discriminator_skip_conv_0')

    resblocks = [
        make_disc_resblock(n_channels=n_channels, window_size=8, dilation_rate=1, group_ix=0, layer_ix=block_ix)
        for block_ix in range(n_resblocks)
    ]

    last_block_conv = _pointwise_conv('policy_discriminator_last_block_conv')

    skip_add = Lambda(lambda x: x[0] + x[1], name='policy_discriminator_skip_add')

    final_flatten = Flatten()

    z_mean = Dense(latent_size, name='policy_discriminator_z_mean')
    z_log_var = Dense(latent_size, name='policy_discriminator_z_log_var')

    def _encoder_func(sequence_input) :
        output_tensor = policy_conv_0(sequence_input)

        #Long-range skip taken before the residual blocks
        skip_conv_0_out = skip_conv_0(output_tensor)

        #Res block group 0
        for resblock in resblocks :
            output_tensor = resblock(output_tensor)

        #Extra convolution after the last residual block, then add the skip
        skip_add_out = skip_add([last_block_conv(output_tensor), skip_conv_0_out])

        flat_features = final_flatten(skip_add_out)

        #Z mean and log variance
        return z_mean(flat_features), z_log_var(flat_features)

    return _encoder_func

#Encoder Model definition
def load_encoder_network_8_resblocks(batch_size, seq_length=128, drop_rate=0.25) :
    """Create the encoder layers (two groups of 4 residual blocks) and return a function applying them.

    Group 0 uses dilation rate 1 and group 1 uses dilation rate 4. A (1, 1)
    convolution skip is taken before each group and both skips are added to
    the output of the final (1, 1) convolution. The returned function maps a
    sequence tensor to (z_mean, z_log_var), each from a Dense layer with 100
    units.

    `batch_size`, `seq_length` and `drop_rate` are not read in the body.
    """
    n_resblocks = 4
    n_channels = 32
    latent_size = 100

    def _pointwise_conv(layer_name) :
        #All (1, 1) convolutions of this encoder share the same configuration
        return Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name=layer_name)

    policy_conv_0 = _pointwise_conv('policy_discriminator_conv_0')

    #Res block group 0 (dilation rate 1)
    skip_conv_0 = _pointwise_conv('policy_discriminator_skip_conv_0')

    resblocks_0 = [
        make_disc_resblock(n_channels=n_channels, window_size=8, dilation_rate=1, group_ix=0, layer_ix=block_ix)
        for block_ix in range(n_resblocks)
    ]

    #Res block group 1 (dilation rate 4)
    skip_conv_1 = _pointwise_conv('policy_discriminator_skip_conv_1')

    resblocks_1 = [
        make_disc_resblock(n_channels=n_channels, window_size=8, dilation_rate=4, group_ix=1, layer_ix=block_ix)
        for block_ix in range(n_resblocks)
    ]

    last_block_conv = _pointwise_conv('policy_discriminator_last_block_conv')

    skip_add = Lambda(lambda x: x[0] + x[1] + x[2], name='policy_discriminator_skip_add')

    final_flatten = Flatten()

    z_mean = Dense(latent_size, name='policy_discriminator_z_mean')
    z_log_var = Dense(latent_size, name='policy_discriminator_z_log_var')

    def _encoder_func(sequence_input) :
        output_tensor = policy_conv_0(sequence_input)

        #Res block group 0
        skip_conv_0_out = skip_conv_0(output_tensor)

        for resblock in resblocks_0 :
            output_tensor = resblock(output_tensor)

        #Res block group 1
        skip_conv_1_out = skip_conv_1(output_tensor)

        for resblock in resblocks_1 :
            output_tensor = resblock(output_tensor)

        #Extra convolution after the last residual block, then add both skips
        skip_add_out = skip_add([last_block_conv(output_tensor), skip_conv_0_out, skip_conv_1_out])

        flat_features = final_flatten(skip_add_out)

        #Z mean and log variance
        return z_mean(flat_features), z_log_var(flat_features)

    return _encoder_func

#PWM Masking and Sampling helper functions

def mask_pwm(inputs) :
    """Return pwm * mask + template for the triple (pwm, onehot_template, onehot_mask)."""
    raw_pwm, template, mask = inputs

    masked_pwm = raw_pwm * mask
    return masked_pwm + template

def sample_pwm_only(pwm_logits) :
    """Sample a one-hot nucleotide at every position using st_sampled_softmax.

    The logits are flattened to rows of 4, sampled row-wise, and reshaped to
    (n_sequences, 1, seq_length, 4), where n_sequences and seq_length are
    dimensions 0 and 2 of the input.
    """
    n_rows = K.shape(pwm_logits)[0]
    n_positions = K.shape(pwm_logits)[2]

    per_position_logits = K.reshape(pwm_logits, (n_rows * n_positions, 4))
    onehot_rows = st_sampled_softmax(per_position_logits)

    return K.reshape(onehot_rows, (n_rows, 1, n_positions, 4))

def sample_pwm(pwm_logits) :
    """One-hot nucleotides per position: sampled in the learning phase, argmax otherwise.

    Both branches are built on the flattened (rows of 4) logits and
    K.switch selects between them on K.learning_phase(). The result is
    reshaped to (n_sequences, 1, seq_length, 4).
    """
    n_rows = K.shape(pwm_logits)[0]
    n_positions = K.shape(pwm_logits)[2]

    per_position_logits = K.reshape(pwm_logits, (n_rows * n_positions, 4))
    train_branch = st_sampled_softmax(per_position_logits)
    eval_branch = st_hardmax_softmax(per_position_logits)
    onehot_rows = K.switch(K.learning_phase(), train_branch, eval_branch)

    return K.reshape(onehot_rows, (n_rows, 1, n_positions, 4))

def max_pwm(pwm_logits) :
    """Pick the argmax nucleotide at every position using st_hardmax_softmax.

    The logits are flattened to rows of 4, converted row-wise, and reshaped
    to (n_sequences, 1, seq_length, 4).
    """
    n_rows = K.shape(pwm_logits)[0]
    n_positions = K.shape(pwm_logits)[2]

    per_position_logits = K.reshape(pwm_logits, (n_rows * n_positions, 4))
    onehot_rows = st_hardmax_softmax(per_position_logits)

    return K.reshape(onehot_rows, (n_rows, 1, n_positions, 4))


#Generator helper functions
def initialize_sequence_templates_model(generator, sequence_templates) :
    """Write template logits and masks into a model's 'template_dense' and 'mask_dense' layers.

    For each template string (one embedding row per string):
      * 'N' positions keep their encoded template row and get mask 1.0;
      * 'X' positions get template value -1.0 in all four channels, mask 0;
      * every other position gets -4.0 in all channels except 10.0 at the
        argmax of its encoded row, mask 0.

    Both layers receive the flattened rows as their weights and are marked
    non-trainable. Nothing is returned.
    """
    template_rows = []
    mask_rows = []

    for template_str in sequence_templates :
        n_positions = len(template_str)

        template_logits = isol.OneHotEncoder(seq_length=n_positions)(template_str).reshape((1, n_positions, 4))
        free_mask = np.zeros((1, n_positions, 4))

        for pos, nt in enumerate(template_str) :
            if nt == 'N' :
                free_mask[:, pos, :] = 1.0
            elif nt == 'X' :
                template_logits[:, pos, :] = -1.0
            else :
                fixed_ix = np.argmax(template_logits[0, pos, :])
                template_logits[:, pos, :] = -4.0
                template_logits[:, pos, fixed_ix] = 10.0

        template_rows.append(template_logits.reshape(1, -1))
        mask_rows.append(free_mask.reshape(1, -1))

    template_weights = np.concatenate(template_rows, axis=0)
    mask_weights = np.concatenate(mask_rows, axis=0)

    template_layer = generator.get_layer('template_dense')
    template_layer.set_weights([template_weights])
    template_layer.trainable = False

    mask_layer = generator.get_layer('mask_dense')
    mask_layer.set_weights([mask_weights])
    mask_layer.trainable = False


#Generator construction function
def build_sampler(batch_size, seq_length, n_classes=1, n_samples=None, validation_sample_mode='max') :
    """Create the templating / softmax / sampling layers and return a function applying them.

    The returned function takes (class_input, raw_logits) and returns
    (pwm_logits, pwm, sampled_pwm):
      * pwm_logits: raw logits combined with the class template and mask via mask_pwm;
      * pwm: softmax of pwm_logits over the last axis;
      * sampled_pwm: output of the sampling layer. When n_samples is given,
        the logits are tiled n_samples times along axis 0 before sampling and
        the result is rearranged to (batch_size, n_samples, 1, seq_length, 4).

    validation_sample_mode == 'sample' selects sample_pwm_only as the sampling
    function; any other value selects sample_pwm.
    """
    use_samples = n_samples is not None
    n_tiles = n_samples if use_samples else 1

    #Reshape layer shared by the template and mask embeddings
    reshape_layer = Reshape((1, seq_length, 4))

    #Template and mask lookups, one row per class
    onehot_template_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer='zeros', name='template_dense')
    onehot_mask_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer='ones', name='mask_dense')

    #Templating / masking, then nucleotide-wise normalization
    masking_layer = Lambda(mask_pwm, output_shape = (1, seq_length, 4), name='masking_layer')
    pwm_layer = Softmax(axis=-1, name='pwm')

    #Sampling layers
    sample_func = sample_pwm_only if validation_sample_mode == 'sample' else sample_pwm

    upsampling_layer = Lambda(lambda x: K.tile(x, [n_tiles, 1, 1, 1]), name='upsampling_layer')
    sampling_layer = Lambda(sample_func, name='pwm_sampler')
    permute_layer = Lambda(lambda x: K.permute_dimensions(K.reshape(x, (n_tiles, batch_size, 1, seq_length, 4)), (1, 0, 2, 3, 4)), name='permute_layer')

    def _sampler_func(class_input, raw_logits) :
        #Look up the template and mask of each requested class
        onehot_template = reshape_layer(onehot_template_dense(class_input))
        onehot_mask = reshape_layer(onehot_mask_dense(class_input))

        pwm_logits = masking_layer([raw_logits, onehot_template, onehot_mask])
        pwm = pwm_layer(pwm_logits)

        if use_samples :
            #Tile, sample every copy, then move the sample axis behind the batch axis
            sampled_pwm = permute_layer(sampling_layer(upsampling_layer(pwm_logits)))
        else :
            sampled_pwm = sampling_layer(pwm_logits)

        return pwm_logits, pwm, sampled_pwm

    return _sampler_func


def get_pwm_cross_entropy(pwm_start, pwm_end) :
    """Return a function computing the mean cross-entropy over positions pwm_start:pwm_end.

    The returned function takes [pwm_true, pwm_pred], clips the prediction to
    [epsilon, 1 - epsilon], sums -true * log(pred) over the last axis within
    the position window, averages over the window, and returns the result
    with a trailing axis of size 1.
    """

    def _pwm_cross_entropy(inputs) :
        pwm_true, pwm_pred = inputs

        pwm_pred = K.clip(pwm_pred, K.epsilon(), 1. - K.epsilon())

        true_window = pwm_true[:, 0, pwm_start:pwm_end, :]
        log_pred_window = K.log(pwm_pred[:, 0, pwm_start:pwm_end, :])
        per_position_ce = - K.sum(true_window * log_pred_window, axis=-1)

        return K.expand_dims(K.mean(per_position_ce, axis=-1), axis=-1)

    return _pwm_cross_entropy

def min_pred(y_true, y_pred) :
    """Loss that ignores `y_true` and returns the prediction itself."""
    loss_value = y_pred
    return loss_value

def get_weighted_loss(loss_coeff=1.) :
    """Return a loss function that ignores `y_true` and yields loss_coeff * y_pred."""

    def _min_pred(y_true, y_pred) :
        scaled_pred = loss_coeff * y_pred
        return scaled_pred

    return _min_pred

def get_z_sample(z_inputs):
    """Draw z = z_mean + exp(0.5 * z_log_var) * epsilon with epsilon ~ N(0, 1).

    z_inputs is the pair (z_mean, z_log_var). The noise has one row per row
    of z_mean (dynamic batch size) and as many columns as z_mean's static
    second dimension.
    """
    z_mean, z_log_var = z_inputs

    n_rows = K.shape(z_mean)[0]
    n_latent = K.int_shape(z_mean)[1]
    noise = K.random_normal(shape=(n_rows, n_latent))

    z_std = K.exp(0.5 * z_log_var)
    return z_mean + z_std * noise

def get_z_kl_loss(anneal_coeff) :
    """Return a function computing the annealed KL term from [z_mean, z_log_var].

    The returned function evaluates
    -0.5 * mean(1 + z_log_var - z_mean^2 - exp(z_log_var)) over the last axis,
    adds a trailing axis of size 1, and multiplies by `anneal_coeff`.
    """

    def _z_kl_loss(inputs, anneal_coeff=anneal_coeff) :
        z_mean, z_log_var = inputs

        per_dim_term = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_per_row = K.mean(per_dim_term, axis=-1) * -0.5

        return anneal_coeff * K.expand_dims(kl_per_row, axis=-1)

    return _z_kl_loss

def build_vae(model_path) :
    """Build a fresh VAE and return (vae_model, encoder_model, decoder_model).

    The three models are wired from the same encoder, decoder and sampler
    closures, so they share layer instances. Only vae_model is compiled; its
    two outputs are the 'reconstruction' and 'kl' Lambda layers, trained with
    pass-through losses weighted 1.0 and 1.125.

    NOTE(review): `model_path` is not referenced in this body — no weights are
    loaded here; cbas_opt loads them into the returned encoder/decoder models.
    """
    
    #Simple Library
    sequence_templates = [
        'N' * 96
    ]

    #Initialize Encoder and Decoder networks
    batch_size = 32
    seq_length = 96
    n_samples = None
    latent_size = 100

    #Load Encoder
    encoder = load_encoder_network_4_resblocks(batch_size, seq_length=seq_length, latent_size=latent_size, drop_rate=0.)

    #Load Decoder
    decoder = load_decoder_resnet(seq_length=seq_length, latent_size=latent_size)

    #Load Sampler
    sampler = build_sampler(batch_size, seq_length, n_classes=1, n_samples=n_samples, validation_sample_mode='sample')

    #Build Encoder Model
    encoder_input = Input(shape=(1, seq_length, 4), name='encoder_input')

    z_mean, z_log_var = encoder(encoder_input)

    #The same sampling layer instance is reused below for the full VAE graph
    z_sampling_layer = Lambda(get_z_sample, output_shape=(latent_size,), name='z_sampler')
    z = z_sampling_layer([z_mean, z_log_var])

    # instantiate encoder model
    encoder_model = Model(encoder_input, [z_mean, z_log_var, z])
    #encoder_model.compile(
    #    optimizer=keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999),
    #    loss=min_pred
    #)

    #Build Decoder Model
    decoder_class = Input(shape=(1,), name='decoder_class')
    decoder_input = Input(shape=(latent_size,), name='decoder_input')

    pwm_logits, pwm, sampled_pwm = sampler(decoder_class, decoder(decoder_input))

    decoder_model = Model([decoder_class, decoder_input], [pwm_logits, pwm, sampled_pwm])

    #Initialize Sequence Templates and Masks
    initialize_sequence_templates_model(decoder_model, sequence_templates)

    #decoder_model.compile(
    #    optimizer=keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999),
    #    loss=min_pred
    #)
    
    #Build the end-to-end VAE (encode -> sample z -> decode) on new inputs
    vae_decoder_class = Input(shape=(1,), name='vae_decoder_class')
    vae_encoder_input = Input(shape=(1, seq_length, 4), name='vae_encoder_input')

    encoded_z_mean, encoded_z_log_var = encoder(vae_encoder_input)
    encoded_z = z_sampling_layer([encoded_z_mean, encoded_z_log_var])
    decoded_logits, decoded_pwm, decoded_sample = sampler(vae_decoder_class, decoder(encoded_z))

    #NOTE(review): the cross-entropy covers positions 5:76 only (pwm_end=81-5) although seq_length is 96 — confirm this window is intended
    reconstruction_loss = Lambda(get_pwm_cross_entropy(pwm_start=5, pwm_end=81-5), name='reconstruction')([vae_encoder_input, decoded_pwm])

    #KL multiplier held in a backend variable, initialised to 1.0
    anneal_coeff = K.variable(1.0)

    kl_loss = Lambda(get_z_kl_loss(anneal_coeff), name='kl')([encoded_z_mean, encoded_z_log_var])

    vae_model = Model(
        [vae_decoder_class, vae_encoder_input],
        [reconstruction_loss, kl_loss]
    )

    #Initialize Sequence Templates and Masks
    initialize_sequence_templates_model(vae_model, sequence_templates)

    #The model outputs are already the loss values, so each loss just scales its output
    vae_model.compile(
        optimizer=keras.optimizers.Adam(lr=0.0001, beta_1=0.5, beta_2=0.9),
        loss={
            'reconstruction' : get_weighted_loss(loss_coeff=1.),
            'kl' : get_weighted_loss(loss_coeff=1.125)
        }
    )
    
    return vae_model, encoder_model, decoder_model


In [6]:
import scipy.stats

def _templated_predict(oracle, x, batch_size=32, template=template_mat, mask=mask_mat) :
    """Embed the sequences in `x` into the full-length template and return oracle log-odds.

    x is transposed with axes (0, 2, 3, 1), zero-padded with 25 positions on
    the left and 84 on the right, multiplied by `mask` and added to
    `template`. The oracle is called with that array plus two auxiliary
    inputs: a 13-column array with column 5 set to 1 and a column of ones
    (presumably a library indicator and a distal-site flag — confirm against
    the oracle's inputs). The first oracle output, column 0, is taken as a
    probability and converted to log(p / (1 - p)).
    """
    n_seqs = x.shape[0]

    #Auxiliary oracle inputs
    lib_input = np.zeros((n_seqs, 13))
    lib_input[:, 5] = 1.
    flag_input = np.ones((n_seqs, 1))

    #Pad to full length, then apply mask and template
    padded = np.concatenate([
        np.zeros((n_seqs, 25, 4, 1)),
        np.transpose(x, (0, 2, 3, 1)),
        np.zeros((n_seqs, 84, 4, 1))
    ], axis=1)
    onehots = padded * mask + template

    #Predict fitness
    oracle_outputs = oracle.predict(x=[onehots, lib_input, flag_input], batch_size=batch_size)
    prob_pred = oracle_outputs[0][:, 0]

    return np.log(prob_pred / (1. - prob_pred))

def cbas_opt(X_train, vae_model_path, oracle, vae_0_encoder, vae_0_decoder, weights_type='cbas',
        LD=100, iters=20, samples=500, homoscedastic=False, homo_y_var=0.1,
        quantile=0.95, verbose=False, alpha=1, train_gt_evals=None,
        cutoff=1e-6, it_epochs=10, enc1_units=50, store_every=1, yt_scale=0.8):
    """Iteratively sample sequences from a VAE, score them with the oracle, and retrain the VAE on the reweighted samples.

    Iteration 0 scores X_train and loads the pretrained weights from
    vae_model_path + "_encoder.h5" / "_decoder.h5". Every later iteration
    decodes `samples` random latent vectors, samples one-hot sequences from
    the decoded PWMs, scores them, computes per-sample weights, drops samples
    whose weight is below `cutoff`, and fits the VAE on the rest.

    Parameters
    ----------
    X_train : array indexed as [sequence, 0, position, nucleotide]; scored at iteration 0.
    vae_model_path : path prefix of the pretrained encoder / decoder weight files.
    oracle : model passed to _templated_predict.
    vae_0_decoder : decoder of the initial VAE; its PWMs give the numerator of the 'cbas' density ratio.
    weights_type : 'cbas' or 'rwr'.
    LD : latent dimensionality of the sampled z vectors.
    iters, samples : number of iterations and of sequences drawn per iteration.
    quantile, yt_scale : 'cbas' only — the threshold is the running maximum of the
        per-iteration score quantile; the weight is the normal survival function
        at that threshold with mean yt and scale yt_scale.
    alpha : 'rwr' only — weights are exp(alpha * yt), normalised to sum to the sample count.
    cutoff : samples with weight below this are removed before fitting.
    it_epochs : epochs per fit call.
    store_every : sequences are recorded on iterations where t % store_every == 0.
    verbose : print the median score of every iteration.
    vae_0_encoder, homoscedastic, homo_y_var, train_gt_evals, enc1_units : accepted but not used.

    Returns
    -------
    list of lists of str : for each stored iteration, the scored sequences
    spliced into the global `sequence_template`.
    """
    
    assert weights_type in ['cbas', 'rwr']
    
    vae_model, vae_encoder, vae_decoder = build_vae(vae_model_path)
    
    def get_samples(Xt_p):
        #Draw one nucleotide per position from each decoded PWM
        Xt_sampled = np.zeros_like(Xt_p)
        for i in range(Xt_p.shape[0]):
            for j in range(Xt_p.shape[2]):
                p = Xt_p[i, 0, j, :]
                k = np.random.choice(range(len(p)), p=p)
                Xt_sampled[i, 0, j, k] = 1.
        return Xt_sampled
    
    def _onehot_log_prob(pwm, onehot):
        #Log-probability of each one-hot sequence under its PWM.
        #Entries where the one-hot is 0 contribute exactly 0; computing log(p) * 0
        #directly would give nan whenever p underflows to 0. A zero probability at a
        #selected nucleotide still yields -inf, i.e. a weight of 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            log_terms = np.where(onehot > 0, np.log(pwm) * onehot, 0.)
        return np.sum(log_terms, axis=(1, 2, 3))

    generated_sequences = []
    y_star = -np.inf
    nt_map_inv = {0:'A', 1:'C', 2:'G', 3:'T'}
    
    for t in range(iters):
        ### Take Samples ###
        zt = np.random.randn(samples, LD)
        zt_dummy = np.zeros((samples, 1))
        if t > 0:
            Xt_p = vae_decoder.predict([zt_dummy, zt])[1]
            Xt = get_samples(Xt_p)
        else:
            Xt = X_train
        
        ### Evaluate oracle ###
        yt = _templated_predict(oracle, Xt, batch_size=32)
        
        ### Calculate weights ###
        if t > 0:
            if weights_type == 'cbas': 
                #Density ratio between the initial and the current VAE for the same z
                log_pxt = _onehot_log_prob(Xt_p, Xt)
                X0_p = vae_0_decoder.predict([zt_dummy, zt])[1]
                log_px0 = _onehot_log_prob(X0_p, Xt)
                w1 = np.exp(log_px0-log_pxt)
                #The threshold never decreases across iterations
                y_star = max(y_star, np.percentile(yt, quantile*100))
                w2 = scipy.stats.norm.sf(y_star, loc=yt, scale=yt_scale)
                weights = w1*w2
            elif weights_type == 'rwr':
                weights = np.exp(alpha*yt)
                weights /= np.sum(weights)
                weights *= Xt.shape[0]
        else:
            weights = np.ones(yt.shape[0])
            
        if t % store_every == 0 :
            Xt_seqs = []
            
            for nt_ixs in np.argmax(Xt[:, 0, :, :], axis=-1) :
                xt_seq = ''.join(nt_map_inv[int(nt_ix)] for nt_ix in nt_ixs)
                
                #Splice into the full template; positions 45-50 are overwritten with "AATAAA"
                xt_seq = sequence_template[:25] + xt_seq[:45] + "AATAAA" + xt_seq[45+6:] + sequence_template[25+96:]
                
                Xt_seqs.append(xt_seq)
            
            generated_sequences.append(Xt_seqs)
        
        if verbose:
            print(weights_type.upper(), t, np.median(yt))
        
        ### Train model ###
        if t == 0:
            vae_encoder.load_weights(vae_model_path + "_encoder.h5", by_name=True)
            vae_decoder.load_weights(vae_model_path + "_decoder.h5", by_name=True)
        else:
            cutoff_idx = np.where(weights < cutoff)
            Xt = np.delete(Xt, cutoff_idx, axis=0)
            yt = np.delete(yt, cutoff_idx, axis=0)
            weights = np.delete(weights, cutoff_idx, axis=0)
            
            #Nothing survived the cutoff: keep the current model rather than fitting on empty arrays
            if Xt.shape[0] == 0 :
                if verbose:
                    print(weights_type.upper(), t, "no samples above cutoff; skipping update")
                continue
            
            dummy_train = np.zeros((Xt.shape[0], 1))
            
            # train the autoencoder
            _ = vae_model.fit(
                [dummy_train, Xt],
                [dummy_train, dummy_train],
                shuffle=False,
                sample_weight=[weights, weights],
                epochs=it_epochs,
                batch_size=32,
                verbose=0
            )
    
    return generated_sequences


In [7]:
#Load cached dataframe

n_train = 5000
n_test = 1000

n_seqs = n_train + n_test

#Read the first tab-separated field of each line until n_seqs sequences are collected
seqs = []
with open('../vae/apa_simple_seqs_strong.txt', 'rt') as f :
    for line_raw in f :
        seqs.append(line_raw.strip().split("\t")[0])

        if len(seqs) >= n_seqs :
            break

print("len(seqs) = " + str(len(seqs)) + " (loaded)")

short_encoder = isol.OneHotEncoder(96)

#Encode the 96-nt window starting at position 25 and stack to (n, 1, 96, 4)
x_train = np.concatenate([short_encoder(seq[25:25+96])[None, None, ...] for seq in seqs[:n_train]], axis=0)
x_test = np.concatenate([short_encoder(seq[25:25+96])[None, None, ...] for seq in seqs[n_train:]], axis=0)

print(x_train.shape)
print(x_test.shape)


len(seqs) = 6000 (loaded)
(5000, 1, 96, 4)
(1000, 1, 96, 4)


In [8]:
#Load predictor

#Saved Keras predictor file, given relative to the notebook's working directory
predictor_path = '../../../../aparent/saved_models/aparent_plasmid_iso_cut_distalpas_all_libs_no_sampleweights_sgd.h5'

#Deserialize the full model; it is passed to cbas_opt as the oracle
oracle = load_model(predictor_path)


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [9]:
#Load models

vae_path = "../vae/saved_models/vae_apa_max_isoform_simple_strong_len_96_latent_100_epochs_50_kl_factor_1125_annealed"

#Names the saved models reference that Keras cannot resolve on its own
vae_custom_objects = {
    'st_sampled_softmax' : st_sampled_softmax,
    'st_hardmax_softmax' : st_hardmax_softmax,
    'min_pred' : lambda y_true, y_pred: y_pred
}

vae_0_encoder = load_model(vae_path + '_encoder.h5', custom_objects=vae_custom_objects)
vae_0_decoder = load_model(vae_path + '_decoder.h5', custom_objects=vae_custom_objects)


Instructions for updating:
Use tf.random.categorical instead.


In [None]:

#Only used to build the output directory name in the next cell
vae_prefix_str = "_epochs_50_kl_factor_1125"

#Run configuration
weights_type = 'cbas'
run_ix = 0
#Passed to cbas_opt as `iters`, i.e. the number of sample / reweight / retrain iterations
n_epochs = 150
n_samples = 1000
quantile = 0.8
yt_scale = 0.1
#Only read by the 'rwr' weighting branch of cbas_opt
alpha = 0.5
#VAE training epochs per iteration
it_epochs = 1

#With store_every=1 the result holds one list of sequences per iteration
generated_sequences = cbas_opt(x_train, vae_path, oracle, vae_0_encoder, vae_0_decoder,
        LD=100, iters=n_epochs, samples=n_samples, weights_type=weights_type, alpha=alpha,
        quantile=quantile, verbose=True, cutoff=1e-6, it_epochs=it_epochs, store_every=1, yt_scale=yt_scale)


In [37]:

#Directory name encodes the run configuration; dots are stripped from the float settings
experiment_name = "".join([
    "apa_", weights_type,
    "_vae", vae_prefix_str,
    "_iters_", str(n_epochs),
    "_samples_", str(n_samples),
    "_q_", str(quantile).replace(".", ""),
    "_yt_scale_", str(yt_scale).replace(".", ""),
    "_alpha_", str(alpha).replace(".", ""),
    "_it_epochs_", str(it_epochs),
    "_run_", str(run_ix)
])

experiment_dir = 'cbas/' + experiment_name

if not os.path.isdir(experiment_dir):
    os.makedirs(experiment_dir)

#One text file per iteration, one sequence per line
for epoch_i in range(n_epochs) :
    with open(experiment_dir + "/" + "iter_" + str(epoch_i) + '.txt', 'wt') as f :
        f.writelines(seq + "\n" for seq in generated_sequences[epoch_i])
