In [1]:
import os
import sys
import warnings
import random
import copy
import pickle
import glob

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

sys.path.append('../common')
import data_io_utils
import paths
import constants
import utils

sys.path.append('../A003_policy_optimization/')
import models
import A003_common

import A006_common
from unirep import babbler1900 as babbler
import sequence_selection


%reload_ext autoreload
%autoreload 2

## CONFIG

In [2]:
PROTEIN = 'BLAC' # 'GFP' or 'BLAC'

# Wild-type reference sequence and S3/local directory holding this protein's
# simulated-annealing result pickles.
if PROTEIN == 'GFP':
    WT_SEQ = constants.AVGFP_AA_SEQ
    data_dir = os.path.join(data_io_utils.S3_DATA_ROOT, 'chip_1/simulated_annealing/GFP')
elif PROTEIN == 'BLAC':
    WT_SEQ = constants.BETA_LAC_AA_SEQ
    data_dir = os.path.join(data_io_utils.S3_DATA_ROOT, 'chip_1/simulated_annealing/beta_lactamase')
else:
    # Fail loudly here instead of with a NameError on `data_dir` below.
    raise ValueError("PROTEIN must be 'GFP' or 'BLAC', got %r" % (PROTEIN,))

print(data_dir)

/notebooks/analysis/common/../../data/s3/chip_1/simulated_annealing/beta_lactamase


## Sync data

In [3]:
# Pull this protein's simulated-annealing result files from S3 into the local data_dir.
data_io_utils.sync_s3_path_to_local(data_dir) # Very large sync ~300GB

In [4]:
# Also pull paths.EVOTUNING_CKPT_DIR; presumably this holds the evotuned UniRep
# weights referenced by the *_ET_*_WEIGHT_PATH constants used below -- confirm.
data_io_utils.sync_s3_path_to_local(paths.EVOTUNING_CKPT_DIR)

## UniRep utility functions

In [5]:
# Batch size the UniRep babbler is constructed with; generate_reps asserts that
# no call passes more sequences than this.
UNIREP_BATCH_SIZE = 3500

# Ensembled-ridge hyperparameters, read from the shared models config and used
# when training the non-sparse-refit top model during validation.
(TOP_MODEL_ENSEMBLE_NMEMBERS,
 TOP_MODEL_SUBSPACE_PROPORTION,
 TOP_MODEL_NORMALIZE,
 TOP_MODEL_PVAL_CUTOFF) = (
    models.ENSEMBLED_RIDGE_PARAMS[param_key]
    for param_key in ('n_members', 'subspace_proportion', 'normalize', 'pval_cutoff')
)

if PROTEIN == 'GFP':
    print('GFP')
    def load_base_model(model_name):
        """Return the featurizing base model called `model_name` for GFP.

        'ET_Global_Init_1', 'ET_Global_Init_2' and 'ET_Random_Init_1' build a
        UniRep `babbler` (batch size UNIREP_BATCH_SIZE) from the matching GFP
        evotuned weights. 'OneHot' builds a models.OneHotRegressionModel that is
        only used to generate one-hot reps.

        Raises:
            ValueError: if `model_name` is not one of the names above.
        """
        if model_name == 'ET_Global_Init_1':
            base_model = babbler(batch_size=UNIREP_BATCH_SIZE, model_path=paths.GFP_ET_GLOBAL_INIT_1_WEIGHT_PATH)
            print('Loading weights from:', paths.GFP_ET_GLOBAL_INIT_1_WEIGHT_PATH)
        elif model_name == 'ET_Global_Init_2':
            base_model = babbler(batch_size=UNIREP_BATCH_SIZE, model_path=paths.GFP_ET_GLOBAL_INIT_2_WEIGHT_PATH)
            print('Loading weights from:', paths.GFP_ET_GLOBAL_INIT_2_WEIGHT_PATH)
        elif model_name == 'ET_Random_Init_1':
            base_model = babbler(batch_size=UNIREP_BATCH_SIZE, model_path=paths.GFP_ET_RANDOM_INIT_1_WEIGHT_PATH)
            print('Loading weights from:', paths.GFP_ET_RANDOM_INIT_1_WEIGHT_PATH)
        elif model_name == 'OneHot':
            # Just need it to generate one-hot reps.
            # Top model created within OneHotRegressionModel doesn't actually get used.
            base_model = models.OneHotRegressionModel('EnsembledRidge')
        else:
            # Explicit raise rather than `assert False`, which is skipped under `python -O`.
            raise ValueError('Unsupported base model: %r' % (model_name,))

        return base_model

elif PROTEIN == 'BLAC':
    print('BLAC')
    class BetaLacOneHotEncoder(object):
        """Stateless one-hot featurizer exposing an `encode_seqs` method."""

        def encode_seqs(self, seqs):
            """Return utils.encode_aa_seq_list_as_matrix_of_flattened_one_hots(seqs)."""
            return utils.encode_aa_seq_list_as_matrix_of_flattened_one_hots(seqs)

    def load_base_model(model_name):
        """Return the featurizing base model called `model_name` for beta-lactamase.

        'ET_Global_Init_1' and 'ET_Random_Init_1' build a UniRep `babbler`
        (batch size UNIREP_BATCH_SIZE) from the matching BLAC evotuned weights.
        'OneHot' returns a BetaLacOneHotEncoder.

        Raises:
            ValueError: if `model_name` is not one of the names above.
        """
        if model_name == 'ET_Global_Init_1':
            base_model = babbler(batch_size=UNIREP_BATCH_SIZE, model_path=paths.BLAC_ET_GLOBAL_INIT_1_WEIGHT_PATH)
            print('Loading weights from:', paths.BLAC_ET_GLOBAL_INIT_1_WEIGHT_PATH)
        elif model_name == 'ET_Random_Init_1':
            base_model = babbler(batch_size=UNIREP_BATCH_SIZE, model_path=paths.BLAC_ET_RANDOM_INIT_1_WEIGHT_PATH)
            print('Loading weights from:', paths.BLAC_ET_RANDOM_INIT_1_WEIGHT_PATH)
        elif model_name == 'OneHot':
            # Just need it to generate one-hot reps.
            # Doing it this way to be consistent with the GFP pipeline
            base_model = BetaLacOneHotEncoder()
        else:
            # Explicit raise rather than `assert False`, which is skipped under `python -O`.
            raise ValueError('Unsupported base model: %r' % (model_name,))

        return base_model

else:
    # Without this, load_base_model would silently be left undefined.
    raise ValueError("PROTEIN must be 'GFP' or 'BLAC', got %r" % (PROTEIN,))

def generate_reps(seq_list, base_model, sess):
    """Featurize `seq_list` with `base_model`.

    If `base_model` is a UniRep babbler (class name 'babbler1900'), each
    sequence's hidden states from `get_all_hiddens` are averaged over axis 0.
    The averages are then stacked row-wise; at most UNIREP_BATCH_SIZE sequences
    are allowed. Any other model is treated as a one-hot encoder, and
    `base_model.encode_seqs(seq_list)` is returned unchanged. `sess` is only
    used on the babbler path.
    """
    is_unirep_babbler = base_model.__class__.__name__ == 'babbler1900'
    if not is_unirep_babbler:
        # one hot model
        return base_model.encode_seqs(seq_list)

    assert len(seq_list) <= UNIREP_BATCH_SIZE
    per_seq_hiddens = base_model.get_all_hiddens(seq_list, sess)
    mean_hiddens = [np.mean(seq_hiddens, axis=0) for seq_hiddens in per_seq_hiddens]
    return np.stack(mean_hiddens, 0)

BLAC


## Plotting utility functions

In [6]:
def sr_vs_nsr_pred_plot(nsr_yhat_wt, sr_yhat_wt, nsr_yhat_top, sr_yhat_top, output_dir):
    """Scatter sparse-refit (SR) vs. non-sparse-refit (NSR) predictions of the top seqs.

    Red lines mark each model's WT prediction. The figure is annotated with how
    many top sequences each model scores above WT. It is saved to
    `output_dir/sr_vs_nsr_plot.png` and then closed.
    """
    fig, ax = plt.subplots()
    ax.plot(sr_yhat_top, nsr_yhat_top, '.k')
    ax.axvline(sr_yhat_wt, color='r')
    ax.axhline(nsr_yhat_wt, color='r')

    annotations = (
        (0.9, 'SR predicted top seqs >WT: %d' % np.sum(sr_yhat_top > sr_yhat_wt)),
        (0.8, 'NSR predicted top seqs >WT: %d' % np.sum(nsr_yhat_top > nsr_yhat_wt)),
    )
    for y_pos, label in annotations:
        ax.text(x=0.6, y=y_pos, s=label,
                horizontalalignment='right', verticalalignment='center',
                transform=ax.transAxes)

    ax.set_xlabel('Predicted qfunc, original sparse refit top model')
    ax.set_ylabel('Predicted qfunc, non sparse refit top model')

    fig.savefig(os.path.join(output_dir, 'sr_vs_nsr_plot.png'), bbox_inches='tight')
    plt.close(fig)

def top_seq_and_traj_plot(fit_mat, trajectory_indices_yielding_top_seqs,
                          seq_indices_inside_top_trajectories, output_dir, n_show=3):
    """Plot the fitness trajectories that produced the top selections.

    For each of the first `n_show` selections, this draws the full trajectory
    (column `trajectory_indices_yielding_top_seqs[i]` of `fit_mat`). It marks
    the selected iteration (row `seq_indices_inside_top_trajectories[i]`) with
    a black dot. The figure is saved to `output_dir/top_seq_and_traj_plot.png`.

    Args:
        fit_mat: fitness history indexed [iteration, trajectory].
        trajectory_indices_yielding_top_seqs: column index per selection.
        seq_indices_inside_top_trajectories: row index per selection.
        output_dir: directory for the PNG.
        n_show: maximum number of trajectories to draw. Fewer are drawn if
            fewer selections exist, instead of raising IndexError.
    """
    # Use the `fit_mat` argument; the old version overwrote it with the
    # notebook-global `res_sa['fitness_history']`.
    fig, ax = plt.subplots()
    selections = zip(trajectory_indices_yielding_top_seqs[:n_show],
                     seq_indices_inside_top_trajectories[:n_show])
    for traj_idx, seq_idx in selections:
        ax.plot(fit_mat[:, traj_idx], '-')
        ax.plot(seq_idx, fit_mat[seq_idx, traj_idx], '.k')

    out_file = os.path.join(output_dir, 'top_seq_and_traj_plot.png')
    fig.savefig(out_file, bbox_inches='tight')
    plt.close(fig)
    
def qfunc_hist_plot(wt_qfunc, init_fitness, top_seq_fitness, output_dir):
    """Overlay histograms of initial vs. optimized predicted fitness.

    `init_fitness` is drawn in blue and `top_seq_fitness` in red, with a black
    vertical line at `wt_qfunc`. The figure is annotated with how many values
    of each group exceed WT. Infinite values are plotted as 0. The input arrays
    are not modified. The figure is saved to `output_dir/qfunc_hist_plot.png`.
    """
    # Work on copies: assigning into the arguments would silently change the
    # caller's data (the caller passes a row of the SA fitness history).
    init_fitness = np.where(np.isinf(init_fitness), 0, init_fitness)
    top_seq_fitness = np.where(np.isinf(top_seq_fitness), 0, top_seq_fitness)

    fig, ax = plt.subplots()
    ax.hist(init_fitness, bins=50, color='b', alpha=0.3)
    ax.hist(top_seq_fitness, bins=50, color='r', alpha=0.3)
    ax.axvline(wt_qfunc, color='k')

    ax.text(x=0.6, y=0.9, s='Num initial seqs >WT: %d' % np.sum(init_fitness > wt_qfunc),
            horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)
    ax.text(x=0.6, y=0.8, s='Num optimized seqs >WT: %d' % np.sum(top_seq_fitness > wt_qfunc),
            horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)

    out_file = os.path.join(output_dir, 'qfunc_hist_plot.png')
    fig.savefig(out_file, bbox_inches='tight')
    plt.close(fig)
    
def seq_dist_summary_plots(top_seqs, top_seq_fitness, top_seq_fitness_ensemble, wt_qfunc, output_dir):
    """Save two figures summarizing the diversity of the selected sequences.

    1. 'seq_similarity_summary_plot.png', which has three panels:
       - predicted fitness with 5th-95th percentile ensemble error bars and a
         red WT line;
       - Levenshtein distance of each sequence to WT_SEQ;
       - the pairwise Levenshtein distance heatmap.
       Sequences are ordered by decreasing mean pairwise distance.
    2. 'pairwise_levenshtein_distance_plot.png': counts of every pairwise
       distance value (all matrix entries, including the zero diagonal). The
       median is shown in the title.

    Args:
        top_seqs: list of selected sequences.
        top_seq_fitness: 1-D array of predicted fitness, one per sequence.
        top_seq_fitness_ensemble: per-member predictions, one row per
            sequence (percentiles are taken along axis 1).
        wt_qfunc: predicted WT fitness.
        output_dir: directory for the PNGs.
    """
    ref_seq = WT_SEQ
    n_seqs = len(top_seq_fitness)

    # Asymmetric error bars: distance from the estimate down to the 5th and up
    # to the 95th ensemble percentile. Clip at 0 because matplotlib rejects
    # negative yerr, which occurs if the estimate lies outside the 5-95 band.
    pct = np.percentile(top_seq_fitness_ensemble, [5, 95], axis=1).T
    errb = pct - top_seq_fitness.reshape((-1, 1))
    errb[:, 0] = -errb[:, 0]
    errb = np.clip(errb, 0, None)

    # Float dtype so the inf masking below also works for an integer matrix.
    ld_mat = np.asarray(utils.levenshtein_distance_matrix(top_seqs), dtype=float)
    ld_sidx = np.argsort(-np.mean(ld_mat, axis=0))
    ld_mat = ld_mat[ld_sidx][:, ld_sidx]
    # Zero distances (diagonal / duplicates) are drawn as inf so imshow leaves
    # them uncoloured; presumably so they don't compress the colour scale.
    heatmap_mat = np.where(ld_mat == 0, np.inf, ld_mat)

    fig, (a0, a1, a2) = plt.subplots(3, 1, gridspec_kw={'height_ratios': [1, 1, 4], 'hspace': 0.025},
                                     figsize=(2*5, 2*7.25))
    a0.errorbar(np.arange(n_seqs), top_seq_fitness[ld_sidx], yerr=errb[ld_sidx].T, fmt='.k', zorder=-1)
    a0.axhline(wt_qfunc, color='r')
    a0.set_xlim([-0.5, n_seqs - 0.5])
    a0.set_xticks([])
    a0.set_ylabel('predicted\nfitness')
    a0.set_ylim([0, 1.5])
    a0.grid(True)

    ld_ref = utils.levenshtein_distance_matrix([ref_seq], top_seqs).reshape(-1)
    a1.bar(np.arange(len(ld_ref)), ld_ref[ld_sidx], width=1, color='gray')
    a1.set_xlim([-0.5, n_seqs - 0.5])
    a1.set_xticks([])
    a1.set_ylabel('num. mutations\nto reference')
    a1.grid(True)

    im = a2.imshow(heatmap_mat, aspect='auto')
    a2.set_ylabel('sequence')
    a2.set_xlabel('sequence')

    cb_ax = fig.add_axes([0.93, 0.23, 0.02, 0.3])
    cbar = fig.colorbar(im, cax=cb_ax)
    cbar.ax.get_yaxis().labelpad = 15
    cbar.ax.set_ylabel('num. mutations', rotation=270)

    out_file = os.path.join(output_dir, 'seq_similarity_summary_plot.png')
    fig.savefig(out_file, bbox_inches='tight')
    plt.close(fig)

    # Distribution of all pairwise distances (zeros included, unlike the heatmap).
    pairwise_dists = ld_mat.reshape(-1)
    fig, ax = plt.subplots()
    uv, uc = np.unique(pairwise_dists, return_counts=True)
    ax.bar(uv, uc)
    ax.set_title('Pairwise levenshtein distance distribution\nMedian pw lev dist=%d' % np.median(pairwise_dists))

    out_file = os.path.join(output_dir, 'pairwise_levenshtein_distance_plot.png')
    fig.savefig(out_file, bbox_inches='tight')
    plt.close(fig)

## Main sequence selection functions

In [7]:
def load_results(res_file):
    """Unpickle one SA result file and convert its SA histories to arrays.

    Returns:
        (res, res_sa): the raw unpickled object, and
        sequence_selection.convert_result_vals_to_mat(res['sa_results']).
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only load
    # result files from a trusted source.
    with open(res_file, 'rb') as fh:
        raw_results = pickle.load(fh)

    sa_histories = sequence_selection.convert_result_vals_to_mat(raw_results['sa_results'])
    return raw_results, sa_histories
    

def select_top_seqs(res_file, nseq_select, burnin=250, max_sa_itr=None):
    """Select the `nseq_select` best sequences from one simulated-annealing run.

    First, the best sequence of every SA trajectory is found with
    sequence_selection.get_best_sequence_in_each_trajectory. Those
    best-in-trajectory sequences are then ranked by predicted fitness, and the
    top `nseq_select` are kept. If the run has fewer trajectories, fewer are
    kept.

    Args:
        res_file: path to a pickled SA result file.
        nseq_select: number of sequences to select.
        burnin, max_sa_itr: forwarded to get_best_sequence_in_each_trajectory.

    Returns:
        (select_df, res, res_sa):
        - select_df has columns id, seq_idx (row of res_sa['fitness_history']),
          trajectory_idx (column), predicted_fitness,
          ensemble_predicted_fitness and seq;
        - res and res_sa are as returned by load_results.
    """
    print('SELECTION')
    print(res_file)
    print('Loading results and converting SA histories to numpy arrays')
    res, res_sa = load_results(res_file)

    print('Selecting top sequences')
    # top_seq_idx says, for each trajectory, where in it the best sequence is.
    top_seqs, top_seq_fitness, _, top_seq_idx = sequence_selection.get_best_sequence_in_each_trajectory(
            res_sa, burnin=burnin, max_sa_itr=max_sa_itr)

    # Rank trajectories by their best fitness; these are our official selections.
    top_sidx = np.argsort(-top_seq_fitness)[:nseq_select]

    trajectory_indices_yielding_top_seqs = top_sidx
    seq_indices_inside_top_trajectories = top_seq_idx[top_sidx]
    selected_top_seqs = top_seqs[top_sidx] ## official selection
    selected_top_seq_fitness = top_seq_fitness[top_sidx] ## official selection
    index_pairs = list(zip(seq_indices_inside_top_trajectories, trajectory_indices_yielding_top_seqs))

    mem_pred_history = res_sa['fitness_mem_pred_history']
    selected_top_ensemble_fitness_preds = [mem_pred_history[s_idx][t_idx] for s_idx, t_idx in index_pairs]

    # IDs are '<result file name minus .p>-seq_idx_<row>_<col>'. Strip only a
    # trailing '.p'; str.replace would also drop '.p' elsewhere in the name.
    res_file_name = os.path.basename(res_file)
    id_prefix = res_file_name[:-len('.p')] if res_file_name.endswith('.p') else res_file_name
    seq_ids = [id_prefix + '-seq_idx_' + str(s_idx) + '_' + str(t_idx) for s_idx, t_idx in index_pairs]

    select_df = pd.DataFrame()
    select_df['id'] = seq_ids
    select_df['seq_idx'] = seq_indices_inside_top_trajectories # row idx of res_sa['fitness_history']
    select_df['trajectory_idx'] = trajectory_indices_yielding_top_seqs # col idx of res_sa['fitness_history']
    select_df['predicted_fitness'] = selected_top_seq_fitness
    select_df['ensemble_predicted_fitness'] = selected_top_ensemble_fitness_preds
    select_df['seq'] = selected_top_seqs

    return select_df, res, res_sa

def _infer_base_model_name(run_name):
    """Return the base-model name encoded in an SA run name.

    The name is the second '-'-separated field, or the third if the second is
    'LargeMut'. For example, 'BLAC_SimAnneal-ET_Global_Init_1-0024-00-3e721641'
    gives 'ET_Global_Init_1'.
    """
    fields = run_name.split('-')
    base_model_name = fields[1]
    if base_model_name == 'LargeMut':
        base_model_name = fields[2]
    return base_model_name


def validate_top_seqs(select_df, output_dir, res, res_sa, burnin=250, max_sa_itr=None, res_file=None):
    """Sanity-check a selection from select_top_seqs and write diagnostic plots.

    Asserts that:
      * each selected sequence, its fitness and its per-member ensemble
        predictions match a direct lookup into the SA histories at
        (seq_idx, trajectory_idx);
      * each recorded fitness equals the mean of its ensemble predictions
        (atol 1e-5);
      * rescoring the re-featurized selections with res['top_model']
        correlates > 0.99 with the recorded fitness.
    It also trains a non-sparse-refit (NSR) ensembled ridge on res['train_df']
    and writes comparison/summary PNGs into `output_dir`.

    Args:
        select_df: DataFrame produced by select_top_seqs.
        output_dir: existing directory for the PNGs.
        res, res_sa: raw and array-converted SA results (see load_results).
        burnin, max_sa_itr: forwarded to get_best_sequence_in_each_trajectory.
        res_file: optional path of the SA result file, used only to pick the
            base model. If omitted, the run name is read from the prefix of
            select_df['id']. That prefix is the result file name, because ids
            are built from it.
    """
    fit_mat = res_sa['fitness_history']
    seq_mat = res_sa['seq_history']

    print('VALIDATION')
    trajectory_indices_yielding_top_seqs = np.array(select_df['trajectory_idx'])
    seq_indices_inside_top_trajectories = np.array(select_df['seq_idx'])
    selected_top_seq_fitness = np.array(select_df['predicted_fitness'])
    top_seq_fitness_ensemble = np.stack(select_df['ensemble_predicted_fitness'])
    selected_top_seqs = np.array(select_df['seq'])

    # Manually re-extract each selection from the raw histories and compare it
    # with what the selection code recorded.
    print('Validating sequence selection doing a manual re-extraction')
    mem_pred_history = res_sa['fitness_mem_pred_history']
    index_pairs = zip(seq_indices_inside_top_trajectories, trajectory_indices_yielding_top_seqs)
    for i, (s_idx, t_idx) in enumerate(index_pairs):
        man_sel_fitness_ens = mem_pred_history[s_idx][t_idx]
        man_sel_fitness = fit_mat[s_idx, t_idx]
        man_sel_seq = seq_mat[s_idx, t_idx]

        assert man_sel_fitness == selected_top_seq_fitness[i]
        assert man_sel_seq == selected_top_seqs[i]
        assert np.allclose(man_sel_fitness_ens, top_seq_fitness_ensemble[i], equal_nan=True)

    # The recorded fitness should be the ensemble mean.
    recalc_pred_fitness = np.mean(top_seq_fitness_ensemble, axis=1)
    assert np.allclose(recalc_pred_fitness - selected_top_seq_fitness, 0, atol=1e-5)

    # Rescore with the original sparse-refit (SR) top model and a freshly
    # trained non-sparse-refit (NSR) model.
    print('Rescoring sequences')
    print('\tLoading base model')
    tf.reset_default_graph()
    # Don't depend on a notebook-global `res_file`. Parse the base name only,
    # so '-' characters in parent directories can't shift the fields.
    if res_file is not None:
        run_name = os.path.basename(res_file)
    else:
        run_name = str(select_df['id'].iloc[0])
    base_model = load_base_model(_infer_base_model_name(run_name))

    train_df = res['train_df']
    train_seqs = list(train_df['seq'])
    train_qfunc = np.array(train_df['quantitative_function'])

    print('\tGenerating reps')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        train_reps = generate_reps(train_seqs, base_model, sess)
        top_seq_reps = generate_reps(list(selected_top_seqs), base_model, sess)
        wt_seq_rep = generate_reps([WT_SEQ], base_model, sess)

    print('\tTraining an NSR top model')
    nsr_top_model = A003_common.train_ensembled_ridge(
        train_reps,
        train_qfunc,
        n_members=TOP_MODEL_ENSEMBLE_NMEMBERS,
        subspace_proportion=TOP_MODEL_SUBSPACE_PROPORTION,
        normalize=TOP_MODEL_NORMALIZE,
        do_sparse_refit=False,
        pval_cutoff=TOP_MODEL_PVAL_CUTOFF
    )

    sr_top_model = res['top_model']

    nsr_yhat_wt = nsr_top_model.predict(wt_seq_rep)
    sr_yhat_wt = sr_top_model.predict(wt_seq_rep)

    nsr_yhat_top = nsr_top_model.predict(top_seq_reps)
    sr_yhat_top = sr_top_model.predict(top_seq_reps)

    # Freshly predicted fitness of the top seqs must match the recorded values.
    assert np.corrcoef(sr_yhat_top, selected_top_seq_fitness)[0, 1] > 0.99

    print('Generating validation plots')
    sr_vs_nsr_pred_plot(nsr_yhat_wt, sr_yhat_wt, nsr_yhat_top, sr_yhat_top, output_dir)
    top_seq_and_traj_plot(fit_mat, trajectory_indices_yielding_top_seqs,
            seq_indices_inside_top_trajectories, output_dir)

    # all fitnesses for best-in-trajectory sequences
    _, all_top_seq_fitness, _, _ = sequence_selection.get_best_sequence_in_each_trajectory(
            res_sa, burnin=burnin, max_sa_itr=max_sa_itr)
    # Pass copies so the plot can never modify the SA history. fit_mat[0] is a
    # view into res_sa['fitness_history'] when fit_mat is a NumPy array.
    qfunc_hist_plot(sr_yhat_wt, np.array(fit_mat[0]), np.array(all_top_seq_fitness), output_dir)

    seq_dist_summary_plots(list(selected_top_seqs), selected_top_seq_fitness,
                          top_seq_fitness_ensemble, sr_yhat_wt, output_dir)

## Do the sequence selection

In [8]:
# Passed as `burnin` to get_best_sequence_in_each_trajectory (same value as the
# functions' default). Presumably the number of leading SA iterations ignored
# when picking each trajectory's best sequence -- confirm.
BURNIN = 250
MAX_SA_ITR = None # if None use all of them.
# Number of top best-in-trajectory sequences kept per SA result file.
NSEQ_SELECT = 320

In [9]:
# Every simulated-annealing result pickle for this protein, in a stable order.
result_pattern = os.path.join(data_dir, PROTEIN + '_SimAnneal*.p')
result_files = sorted(glob.glob(result_pattern))

# Special case globs (swap in to restrict to a subset of runs):
#result_files = sorted(glob.glob(os.path.join(data_dir, '*SparseRefit_False*.p')))

print(len(result_files))

33


In [10]:
# Select + validate sequences for every result file. Files whose
# selected_seqs_df.pkl already exists are skipped, so re-running is cheap.
for res_file in result_files:
    print()
    # Strip only the trailing '.p'. str.replace('.p', ...) would also rewrite
    # any '.p' elsewhere in the full path (e.g. a directory named 'x.pdata').
    run_path = res_file[:-len('.p')] if res_file.endswith('.p') else res_file
    output_dir = run_path + '-selected_seqs'
    output_file = os.path.join(output_dir, 'selected_seqs_df.pkl')

    if os.path.exists(output_file):
        print('Already done:', output_file)
        continue

    os.makedirs(output_dir, exist_ok=True)

    select_df, res, res_sa = select_top_seqs(res_file, NSEQ_SELECT, burnin=BURNIN, max_sa_itr=MAX_SA_ITR)
    validate_top_seqs(select_df, output_dir, res, res_sa, burnin=BURNIN, max_sa_itr=MAX_SA_ITR)

    select_df.to_pickle(output_file)

    # The result data is very large (~300GB synced in total). Drop this run's
    # results before the next file is loaded, so two runs aren't held at once.
    del res, res_sa


Already done: /notebooks/analysis/common/../../data/s3/chip_1/simulated_annealing/beta_lactamase/BLAC_SimAnneal-ET_Global_Init_1-0024-00-3e721641-selected_seqs/selected_seqs_df.pkl

Already done: /notebooks/analysis/common/../../data/s3/chip_1/simulated_annealing/beta_lactamase/BLAC_SimAnneal-ET_Global_Init_1-0024-01-3a0e3d4-selected_seqs/selected_seqs_df.pkl

Already done: /notebooks/analysis/common/../../data/s3/chip_1/simulated_annealing/beta_lactamase/BLAC_SimAnneal-ET_Global_Init_1-0024-02-31e54146-selected_seqs/selected_seqs_df.pkl

Already done: /notebooks/analysis/common/../../data/s3/chip_1/simulated_annealing/beta_lactamase/BLAC_SimAnneal-ET_Global_Init_1-0024-03-3764e943-selected_seqs/selected_seqs_df.pkl

Already done: /notebooks/analysis/common/../../data/s3/chip_1/simulated_annealing/beta_lactamase/BLAC_SimAnneal-ET_Global_Init_1-0024-04-4502d3-selected_seqs/selected_seqs_df.pkl

Already done: /notebooks/analysis/common/../../data/s3/chip_1/simulated_annealing/beta_lacta

In [11]:
# Push the local data_dir back to S3. This includes the '*-selected_seqs'
# output directories the selection loop writes next to each result file.
data_io_utils.sync_local_path_to_s3(data_dir)