In [2]:
import sys
import os
import itertools
from keras.layers import Input, Dense, Reshape, Flatten
from keras import layers, initializers
from keras.models import Model, load_model
import keras.backend as K
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats
from scipy.stats import norm
from scipy.optimize import minimize
from keras.utils.generic_utils import get_custom_objects
import json
#import tensorflow_probability as tfp

#tfd = tfp.distributions

def subselect_list(li, ixs) :
    """Gather elements of ``li`` at the positions listed in ``ixs``, keeping ``ixs`` order.

    Duplicate indices yield duplicate elements; an empty ``ixs`` yields ``[]``.
    """
    return [li[ix] for ix in ixs]

class IdentityEncoder :
    """One-hot encoder/decoder between symbol strings and (seq_len, n_channels) arrays.

    ``channel_map`` maps each symbol to its channel index. Symbols that are not
    in the map are skipped when encoding, leaving an all-zero row.
    """

    def __init__(self, seq_len, channel_map) :
        self.seq_len = seq_len
        self.n_channels = len(channel_map)
        self.encode_map = channel_map
        # Inverse map: channel index -> symbol.
        self.decode_map = {
            channel_ix: symbol for symbol, channel_ix in self.encode_map.items()
        }

    def encode(self, seq) :
        """Return a new float array of shape (seq_len, n_channels) one-hot encoding ``seq``.

        Raises IndexError if ``seq`` contains a known symbol past position seq_len - 1.
        """
        encoding = np.zeros((self.seq_len, self.n_channels))
        self.encode_inplace(seq, encoding)
        return encoding

    def encode_inplace(self, seq, encoding) :
        """Set ``encoding[i, channel] = 1.`` for each known symbol ``seq[i]``.

        Entries for unknown symbols (and all other entries) are left untouched.
        """
        for i, symbol in enumerate(seq) :
            if symbol in self.encode_map :
                encoding[i, self.encode_map[symbol]] = 1.

    def encode_inplace_sparse(self, seq, encoding_mat, row_index) :
        """Sparse encoding is not supported by this encoder."""
        raise NotImplementedError()

    def decode(self, encoding) :
        """Return the string of argmax symbols, one character per row of ``encoding``.

        An all-zero row decodes to the symbol of channel 0 (argmax picks the first
        index on ties).
        """
        return ''.join(
            self.decode_map[np.argmax(encoding[pos, :])]
            for pos in range(encoding.shape[0])
        )

    def decode_sparse(self, encoding_mat, row_index) :
        """Sparse decoding is not supported by this encoder."""
        raise NotImplementedError()

from keras.backend.tensorflow_backend import set_session

def contain_tf_gpu_mem_usage() :
    """Register a TF1 session with ``gpu_options.allow_growth`` enabled as the Keras session.

    allow_growth lets TensorFlow grow GPU memory on demand instead of
    reserving it all up front.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    set_session(tf.Session(config=session_config))

contain_tf_gpu_mem_usage()


In [3]:
import itertools
from keras.layers import Input, Dense, Reshape, Flatten
from keras import layers, initializers
from keras.models import Model, load_model
from seqtools import SequenceTools as ST
from util import AA, AA_IDX
from util import build_vae
from keras.utils.generic_utils import get_custom_objects
from util import one_hot_encode_aa, partition_data, get_balaji_predictions, get_samples, get_argmax
from util import convert_idx_array_to_aas, build_pred_vae_model, get_experimental_X_y
from util import get_gfp_X_y_aa
from losses import neg_log_likelihood

from gfp_gp import SequenceGP
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

def build_model(M):
    """Build the oracle network: flattened (M, 20) input -> Dense(50, elu) -> Dense(2).

    Callers in this notebook read column 0 of the 2-unit linear output as the score.
    """
    seq_input = Input(shape=(M, 20,))
    flat = Flatten()(seq_input)
    hidden = Dense(50, activation='elu')(flat)
    score = Dense(2)(hidden)
    return Model(inputs=seq_input, outputs=score)


In [4]:
# Specify problem-specific parameters.
# `it` selects the experiment setting: it indexes the oracle-ensemble size
# list below and also determines RANDOM_STATE.
it = 1

TRAIN_SIZE = 5000
# e.g. "5k"; used to build the weight-file suffixes below.
train_size_str = "%ik" % (TRAIN_SIZE/1000)
num_models = [1, 5, 20][it]
RANDOM_STATE = it + 1

# Project helper from util; returns training inputs, targets and a third array
# (presumably ground-truth values -- TODO confirm against util.get_experimental_X_y).
X_train, y_train, gt_train  = get_experimental_X_y(random_state=RANDOM_STATE, train_size=TRAIN_SIZE)

# Number of sequence positions per training example.
L = X_train.shape[1]

# Suffixes selecting the saved VAE / oracle weight files for this setting
# (consumed by the model-loading cells as models/...%s.h5).
vae_suffix = '_%s_%i' % (train_size_str, RANDOM_STATE)
oracle_suffix = '_%s_%i_%i' % (train_size_str, num_models, RANDOM_STATE)

# NOTE(review): this rebinds `AA`, shadowing the `AA` imported from util earlier.
AA = ['a', 'r', 'n', 'd', 'c', 'q', 'e', 'g', 'h', 'i', 'l', 'k', 'm', 'f', 'p', 's', 't', 'w', 'y', 'v']
# Upper-case residue letter -> channel index (its position in AA).
residue_map = {key.upper() : val for val, key in enumerate(AA)}
# NOTE(review): sequence length 237 is hard-coded; presumably equal to L -- confirm.
seq_encoder = IdentityEncoder(237, residue_map)


In [13]:

def _templated_predict(oracles, x, batch_size=32) :
    
    #Predict fitness
    score_pred = np.mean(np.concatenate([oracles[i].predict(x=[x], batch_size=batch_size)[:, 0].reshape(-1, 1)], axis=1), axis=1)
    
    return score_pred

def _sample_one_hot(Xt_p):
    """Draw one categorical sample per (i, j) from ``Xt_p[i, j, :]`` and one-hot encode it.

    Returns an array shaped and typed like ``Xt_p`` with a single 1. per (i, j).
    """
    Xt_sampled = np.zeros_like(Xt_p)
    for i in range(Xt_p.shape[0]):
        for j in range(Xt_p.shape[1]):
            p = Xt_p[i, j, :]
            # One draw per position in row-major order, so the global NumPy RNG
            # stream is consumed in a fixed, reproducible order.
            k = np.random.choice(range(len(p)), p=p)
            Xt_sampled[i, j, k] = 1.
    return Xt_sampled


def _one_hot_to_seqs(X, nt_map_inv):
    """Decode every ``X[i]`` into a string using the argmax channel at each position."""
    return [
        ''.join(nt_map_inv[k] for k in np.argmax(X[i], axis=-1))
        for i in range(X.shape[0])
    ]


def fb_opt(X_train, vae_suffix, oracles, vae_0, vae_0_encoder, vae_0_decoder, weights_type='fbvae',
        LD=100, iters=20, samples=500, 
        quantile=0.8, verbose=False, train_gt_evals=None,
        it_epochs=10, enc1_units=50, store_every=1):
    """Feedback-VAE optimisation loop driven by the oracle ensemble.

    At t == 0, samples are drawn from ``vae_0_decoder``. The acceptance threshold is
    fixed once to the ``quantile`` of the oracle scores of ``X_train``, and the
    working VAE is initialised with ``vae_0``'s weights.

    At every later iteration t, samples are drawn from the working decoder. Those
    whose oracle score is >= the threshold are prepended to the training set, and
    the same number of rows is dropped from its tail, so its size stays
    ``len(X_train)``. The VAE is then trained for one epoch on that set.

    Parameters
    ----------
    X_train : training inputs; axis 1 is the sequence length L
        (presumably one-hot (N, L, 20) -- TODO confirm).
    vae_suffix : suffix of the ``models/vae_0_*_weights<suffix>.h5`` files.
    oracles : models scored through ``_templated_predict``.
    vae_0, vae_0_encoder, vae_0_decoder : pretrained prior VAE and its sub-models.
    weights_type : only 'fbvae' is accepted.
    LD : latent dimension; must match vae_0's latent size.
    iters : number of iterations.
    samples : latent draws per iteration.
    quantile : fraction in [0, 1] used for the fixed threshold.
    verbose : print per-iteration stats and show Keras ``fit`` progress.
    train_gt_evals : unused; kept for signature compatibility.
    it_epochs : currently unused -- each iteration trains exactly one epoch.
    enc1_units : encoder hidden units; must match the saved weights.
    store_every : keep the decoded samples of every ``store_every``-th iteration.

    Returns
    -------
    list of list of str
        For each stored iteration, the decoded amino-acid strings of its samples.
    """
    assert weights_type in ['fbvae']
    L = X_train.shape[1]

    # Build the working VAE from the requested LD / L / enc1_units rather than
    # fixed constants, so these parameters actually take effect.
    vae_model = build_vae(latent_dim=LD, n_tokens=20, seq_length=L, enc1_units=enc1_units)

    vae_model.encoder_.load_weights("models/vae_0_encoder_weights%s.h5" % vae_suffix)
    vae_model.decoder_.load_weights("models/vae_0_decoder_weights%s.h5"% vae_suffix)
    vae_model.vae_.load_weights("models/vae_0_vae_weights%s.h5"% vae_suffix)
    vae = vae_model.vae_
    vae_encoder = vae_model.encoder_
    vae_decoder = vae_model.decoder_

    # Channel index -> upper-case amino-acid letter, built once for all iterations.
    aa_alphabet = ['a', 'r', 'n', 'd', 'c', 'q', 'e', 'g', 'h', 'i', 'l', 'k', 'm', 'f', 'p', 's', 't', 'w', 'y', 'v']
    nt_map_inv = {key : val.upper() for key, val in enumerate(aa_alphabet)}

    generated_sequences = []
    fb_thresh = -np.inf
    n_top = 0
    fit_verbosity = 1 if verbose else 0

    for t in range(iters):
        ### Take samples and score them with the oracle ensemble ###
        zt = np.random.randn(samples, LD)
        # The first iteration samples the prior VAE; later ones the retrained copy.
        decoder = vae_0_decoder if t == 0 else vae_decoder
        Xt_sample_p = decoder.predict([zt])
        Xt_sample = _sample_one_hot(Xt_sample_p)
        yt_sample = _templated_predict(oracles, Xt_sample, batch_size=32)

        if t == 0:
            # Seed the training set and fix the acceptance threshold once.
            Xt = X_train
            yt = _templated_predict(oracles, Xt, batch_size=32)
            fb_thresh = np.percentile(yt, quantile*100)
        else:
            ### Feedback: prepend top samples, drop as many rows from the tail ###
            threshold_idx = np.where(yt_sample >= fb_thresh)[0]
            n_top = len(threshold_idx)
            Xt = np.concatenate([Xt_sample[threshold_idx], Xt])[:Xt.shape[0]]
            yt = np.concatenate([yt_sample[threshold_idx], yt])[:yt.shape[0]]

        if t % store_every == 0 :
            generated_sequences.append(_one_hot_to_seqs(Xt_sample, nt_map_inv))

        if verbose:
            print(weights_type.upper(), t, fb_thresh, np.median(yt_sample), n_top)

        ### Train model ###
        if t == 0:
            vae_encoder.set_weights(vae_0_encoder.get_weights())
            vae_decoder.set_weights(vae_0_decoder.get_weights())
            vae.set_weights(vae_0.get_weights())
        else:
            # Second target is an all-zero dummy for the VAE's second output
            # (presumably the KL term -- TODO confirm in util.build_vae).
            vae.fit(
                [Xt], [Xt, np.zeros(Xt.shape[0])],
                shuffle=False,
                epochs=1,
                batch_size=32,
                verbose=fit_verbosity
            )

    return generated_sequences


In [14]:
# Load the oracle ensemble: num_models predictors, each restored from its own weights file.
oracles = [build_model(L) for _ in range(num_models)]
for i, oracle in enumerate(oracles) :
    oracle.load_weights("models/oracle_%i%s.h5" % (i, oracle_suffix))


In [15]:
# Load the pretrained prior VAE (vae_0); fb_opt copies its weights at iteration 0.
vae_0 = build_vae(latent_dim=20, n_tokens=20, seq_length=237, enc1_units=50)

vae_0_encoder = vae_0.encoder_
vae_0_decoder = vae_0.decoder_
vae_0_vae = vae_0.vae_

# Restore each sub-model from the weight files sharing this setting's suffix.
vae_0_encoder.load_weights("models/vae_0_encoder_weights%s.h5" % vae_suffix)
vae_0_decoder.load_weights("models/vae_0_decoder_weights%s.h5" % vae_suffix)
vae_0_vae.load_weights("models/vae_0_vae_weights%s.h5" % vae_suffix)


In [None]:

# Run configuration (also encoded into the output directory name).
vae_prefix_str = ""
n_epochs = 150      # passed to fb_opt as `iters`
n_samples = 1000    # decoder samples per iteration
quantile = 0.8      # training-set oracle-score quantile used as the threshold

generated_sequences = fb_opt(
    X_train, vae_suffix, oracles,
    vae_0=vae_0_vae,
    vae_0_encoder=vae_0_encoder,
    vae_0_decoder=vae_0_decoder,
    LD=20,
    iters=n_epochs,
    samples=n_samples,
    quantile=quantile,
    verbose=True,
    store_every=1,
)


In [18]:

# Persist each stored iteration's sequences, one text file per iteration.
seed_suffix = ""

experiment_name = (
    "gfp_fb_vae_weak_balaji" + vae_prefix_str
    + "_iters_" + str(n_epochs)
    + "_samples_" + str(n_samples)
    + "_q_" + str(quantile).replace(".", "")
    + seed_suffix
)

os.makedirs('fbvae_weak_balaji/' + experiment_name, exist_ok=True)

for epoch_i in range(n_epochs) :
    with open('fbvae_weak_balaji/%s/iter_%d.txt' % (experiment_name, epoch_i), 'wt') as f :
        for seq in generated_sequences[epoch_i] :
            f.write(seq + "\n")


In [None]:

# Run configuration (also encoded into the output directory name).
vae_prefix_str = ""
n_epochs = 150      # passed to fb_opt as `iters`
n_samples = 1000    # decoder samples per iteration
quantile = 0.8      # training-set oracle-score quantile used as the threshold

generated_sequences = fb_opt(
    X_train, vae_suffix, oracles,
    vae_0=vae_0_vae,
    vae_0_encoder=vae_0_encoder,
    vae_0_decoder=vae_0_decoder,
    LD=20,
    iters=n_epochs,
    samples=n_samples,
    quantile=quantile,
    verbose=True,
    store_every=1,
)


In [20]:

# Persist each stored iteration's sequences, one text file per iteration.
seed_suffix = "_retry_1"

experiment_name = (
    "gfp_fb_vae_weak_balaji" + vae_prefix_str
    + "_iters_" + str(n_epochs)
    + "_samples_" + str(n_samples)
    + "_q_" + str(quantile).replace(".", "")
    + seed_suffix
)

os.makedirs('fbvae_weak_balaji/' + experiment_name, exist_ok=True)

for epoch_i in range(n_epochs) :
    with open('fbvae_weak_balaji/%s/iter_%d.txt' % (experiment_name, epoch_i), 'wt') as f :
        for seq in generated_sequences[epoch_i] :
            f.write(seq + "\n")


In [None]:

# Run configuration (also encoded into the output directory name).
vae_prefix_str = ""
n_epochs = 150      # passed to fb_opt as `iters`
n_samples = 1000    # decoder samples per iteration
quantile = 0.8      # training-set oracle-score quantile used as the threshold

generated_sequences = fb_opt(
    X_train, vae_suffix, oracles,
    vae_0=vae_0_vae,
    vae_0_encoder=vae_0_encoder,
    vae_0_decoder=vae_0_decoder,
    LD=20,
    iters=n_epochs,
    samples=n_samples,
    quantile=quantile,
    verbose=True,
    store_every=1,
)


In [22]:

# Persist each stored iteration's sequences, one text file per iteration.
seed_suffix = "_retry_2"

experiment_name = (
    "gfp_fb_vae_weak_balaji" + vae_prefix_str
    + "_iters_" + str(n_epochs)
    + "_samples_" + str(n_samples)
    + "_q_" + str(quantile).replace(".", "")
    + seed_suffix
)

os.makedirs('fbvae_weak_balaji/' + experiment_name, exist_ok=True)

for epoch_i in range(n_epochs) :
    with open('fbvae_weak_balaji/%s/iter_%d.txt' % (experiment_name, epoch_i), 'wt') as f :
        for seq in generated_sequences[epoch_i] :
            f.write(seq + "\n")
