In [104]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [105]:
import os

# Cap JAX's GPU memory preallocation at 70% of device memory (adjust as needed)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"


import pickle
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from direvo_functions import *
from ruggedness_functions import *
import selection_function_library as slct
import tqdm
from scipy.optimize import curve_fit
import pandas as pd

In [106]:
# Directory that holds the SLIDE .pkl data files loaded throughout this notebook.
# Set the SLIDE_DATA_DIR environment variable to use a different location; the
# original hard-coded path is kept as the fallback so existing setups still work.
slide_data_dir = os.environ.get("SLIDE_DATA_DIR", "/home/jess/Documents/SLIDE_data")


In [107]:
# The DE function

def directedEvolution(rng,
                      selection_strategy,
                      selection_params,
                      empirical=False,
                      N=None,
                      K=None,
                      landscape=None,
                      popsize=100,
                      mut_chance=0.01,
                      num_steps=50,
                      num_reps=10,
                      define_i_pop=None,
                      pre_optimisation_steps=0,
                      average=True):
    """Run replicate directed-evolution simulations, optionally after a pre-optimisation phase.

    Args:
        rng: JAX PRNG key; split into three sub-keys (initial population,
            NK landscape, simulation seeds).
        selection_strategy: Selection callable passed to ``build_selection_function``.
        selection_params: Parameter dict passed alongside ``selection_strategy``.
        empirical: If True, fitness comes from ``landscape`` and the mutation
            function is built with 20 states; otherwise an NK landscape is built
            from ``N`` and ``K`` and the mutation function uses 2 states.
        N: Sequence length. Required when ``define_i_pop`` is None or when
            ``empirical`` is False.
        K: NK interaction parameter (only used when ``empirical`` is False).
        landscape: Empirical landscape (only used when ``empirical`` is True).
        popsize: Number of copies of the random start sequence. Ignored when
            ``define_i_pop`` is supplied.
        mut_chance: Mutation chance handed to ``build_mutation_function``.
        num_steps: Number of steps for the main runs.
        num_reps: Number of independent replicate runs (vmapped).
        define_i_pop: Optional explicit initial population.
        pre_optimisation_steps: If non-zero, first run this many steps with a
            fixed threshold selection (base_chance 0.0, threshold 0.95) and
            start the main runs from the last population of that run.
        average: Unused; kept so existing callers keep working.

    Returns:
        The second element of ``run_directed_evolution``'s return value,
        batched over ``num_reps`` replicates.
    """
    r1, r2, r3 = jr.split(rng, 3)

    # Get initial population.
    # Identity check: `== None` on an array-valued population is an array
    # comparison, not a reliable "was it supplied?" test.
    if define_i_pop is None:
        # A single random binary sequence repeated popsize times.
        i_pop = jnp.array([jr.randint(r1, (N,), 0, 2)] * popsize)
    else:
        i_pop = define_i_pop

    # Function for evaluating fitness (and the matching mutation function).
    if empirical:
        fitness_function = build_empirical_landscape_function(landscape)
        mutation_function = build_mutation_function(mut_chance, 20)
    else:
        fitness_function = build_NK_landscape_function(r2, N, K)
        mutation_function = build_mutation_function(mut_chance, 2)

    # Define selection function.
    selection_function = build_selection_function(
        selection_strategy, selection_params)

    if pre_optimisation_steps != 0:
        pre_op_selection_function = build_selection_function(
            slct.base_chance_threshold_select,
            {'base_chance': 0.0, 'threshold': 0.95})

        pre_op = run_directed_evolution(r3, i_pop=i_pop,
                                        selection_function=pre_op_selection_function,
                                        mutation_function=mutation_function,
                                        fitness_function=fitness_function,
                                        num_steps=pre_optimisation_steps)[1]

        # Continue from the final population of the pre-optimisation run.
        i_pop = pre_op['pop'][-1]

    # Bringing it all together: one jitted, vmapped run per replicate seed.
    vmapped_run = jax.jit(jax.vmap(lambda r: run_directed_evolution(
        r, i_pop, selection_function, mutation_function,
        fitness_function=fitness_function, num_steps=num_steps)[1]))

    # The array of seeds we will take as input.
    # NOTE(review): r3 has already been consumed by the pre-optimisation run
    # when pre_optimisation_steps != 0 (JAX key reuse). Left unchanged so that
    # previously generated results remain reproducible.
    rng_seeds = jr.split(r3, num_reps)
    results = vmapped_run(rng_seeds)

    return results

## NK Ruggedness Prediction

In [108]:
def NK_grid(N_range, num_samples=10):
    """Build matching (num_samples, num_samples) grids of N and K values.

    Row r of the N grid repeats the r-th value of an even spacing between
    N_range[0] and N_range[1]; row r of the K grid is an even spacing from 1
    up to that same N value.
    """
    n_axis = jnp.linspace(N_range[0], N_range[1], num=num_samples)
    # One row of K values per N, each running from 1 to that N.
    k_rows = [jnp.linspace(1, n_value, num_samples) for n_value in n_axis]
    K = jnp.array(k_rows).reshape(num_samples, num_samples)
    N = jnp.repeat(n_axis, num_samples).reshape(num_samples, num_samples)
    return N, K

# 10 x 10 grid of (N, K) settings for N between 10 and 50.
N_grid, K_grid = NK_grid([10, 50])

# Flatten to 100 pairs and reverse the order, so the list starts at the
# largest N / largest K and ends at the smallest.
Ns, Ks = N_grid.flatten(), K_grid.flatten()
Ns = jnp.flip(Ns)
Ks = jnp.flip(Ks)
NKs = list(zip(Ns,Ks))

In [109]:
## Large sweep data
# NOTE: pickle.load can execute arbitrary code from the file; only load
# pickles that you generated yourself or otherwise trust.

file_path = os.path.join(slide_data_dir, "large_strategy_sweep.pkl")
with open(file_path, "rb") as f:
    strategy_data = pickle.load(f)

file_path = os.path.join(slide_data_dir, "large_decay_curve_sweep.pkl")
with open(file_path, "rb") as f:
    decay_data = pickle.load(f)

### Example fitness decay curves

In [110]:
# Two example landscapes with the same N and different K.
# NOTE: this cell passes pre_optimisation_steps, which only the first
# directedEvolution definition in this notebook accepts; re-running it after the
# later redefinition of directedEvolution raises a TypeError.
NK_pair = [(20,14), (20,1)]
#NK_pair = [(40,15), (40,2)]
out = []
params = {'threshold': 0.0, 'base_chance' : 1.0}
for nk in NK_pair:
    run = directedEvolution(jr.PRNGKey(0),
                            N = int(nk[0]),
                            K=int(nk[1]),
                            selection_strategy=slct.base_chance_threshold_select,
                            selection_params = params,
                            popsize=int(300),#int(150),
                            mut_chance=0.5/int(nk[0]),#0.1/int(nk[0]),
                            num_steps=25,
                            num_reps=1,
                            pre_optimisation_steps=20,
                            average=True)
    # Average the fitness over the last axis (presumably the population — confirm).
    out.append(run['fitness'].mean(axis=-1))

# Row 0 is the K=14 run and row 1 the K=1 run (the order of NK_pair).
smooth_rugged = np.array([out[i][0]/out[i][0][0] for i in range(2)]) # Normalising to decay from 1
smooth_rugged_decay_rates = np.array([get_single_decay_rate(i, mut = 0.5) for i in smooth_rugged])
# 25 evenly spaced x values from 0 to 12, one per simulated step.
mutations = np.linspace(0.0,12,25)
# Exponential decay from 1 towards a floor; assumes the first fitted parameter
# is the rate and the last is the floor — TODO confirm against get_single_decay_rate.
fitted_lines = np.array([( np.exp(-mutations*(i[0]))*(1 - i[-1]) + i[-1]) for i in smooth_rugged_decay_rates])

In [111]:
# Save the normalised example curves and their fitted lines as one tuple.
# The plot_data/ directory must already exist.
with open('plot_data/smooth_rugged_example.pkl', 'wb') as f:
    pickle.dump((smooth_rugged, fitted_lines), f)

### Ruggedness prediction accuracy over $K/N$

In [112]:
# Group the sweep as (100 settings, runs, 25 steps) and rescale every curve by
# its first value so that each one starts at 1.
reshaped_decay_curves = decay_data.reshape(100, -1, 25)
normalized_decay_curves = reshaped_decay_curves / reshaped_decay_curves[:, :, :1]

In [113]:
# Fit one decay rate per (setting, run) curve; only the first fitted
# parameter returned by get_single_decay_rate is kept.
decay_rates = np.zeros(normalized_decay_curves.shape[:2])

for i, ii in np.ndindex(*normalized_decay_curves.shape[:2]):
    decay_rates[i, ii] = get_single_decay_rate(normalized_decay_curves[i, ii, :], mut=0.5)[0]

In [114]:
# Per-setting (K + 1) / N and K / N ratios, clipped to the range [0, 1].
# NOTE(review): k_over_ns is reassigned later in the notebook (as (K + 1) / N)
# before any visible cell reads this value.
k_plus_one_over_ns = np.clip((np.array(NKs)[:,1] + 1)/np.array(NKs)[:,0], 0, 1)
k_over_ns = np.clip((np.array(NKs)[:,1])/np.array(NKs)[:,0], 0, 1)

In [115]:
# Save the (K + 1) / N ratios together with the per-run fitted decay rates.
# The plot_data/ directory must already exist.
with open('plot_data/ruggedness_accuracy.pkl', 'wb') as f:
    pickle.dump((k_plus_one_over_ns, decay_rates), f)

### Exploring accuracy over popsize

In [116]:
# Load the population-size sweep. Only unpickle files you trust.
file_path = os.path.join(slide_data_dir, "popsize_accuracy.pkl")
with open(file_path, "rb") as f:
    popsize_data = pickle.load(f)

In [117]:
# Group as (25 settings, runs, 25 steps) and rescale each curve to start at 1.
reshaped_popsize_curves = popsize_data.reshape(25, -1, 25)
popsize_curves = reshaped_popsize_curves / reshaped_popsize_curves[:, :, :1]

In [118]:
# One fitted decay rate for each of the 25 settings x 500 runs.
popsize_decay_rates = np.zeros((25, 500))
for i, ii in np.ndindex(25, 500):
    popsize_decay_rates[i, ii] = get_single_decay_rate(popsize_curves[i, ii, :], mut=1.0)[0]

# The 25 population sizes, 100 to 2500 in steps of 100.
pops = np.linspace(100, 2500, 25, dtype=int)

In [119]:
# Save the fitted decay rates with their population sizes.
# The plot_data/ directory must already exist.
with open('plot_data/popsize_accuracy.pkl', 'wb') as f:
    pickle.dump((popsize_decay_rates, pops), f)

### Exploring accuracy over mutation rate

In [120]:
# Load the mutation-rate sweep. Only unpickle files you trust.
file_path = os.path.join(slide_data_dir, "mutation_rate_accuracy.pkl")
with open(file_path, "rb") as f:
    mut_data = pickle.load(f)

In [121]:
# Group as (25 settings, runs, 25 steps) and rescale each curve to start at 1.
reshaped_mut_curves = mut_data.reshape(25, -1, 25)
mut_curves = reshaped_mut_curves / reshaped_mut_curves[:, :, :1]

In [122]:
# The 25 mutation-rate values; row i of the curves is fitted with muts[i].
muts = np.linspace(0.01, 2, 25)

mut_decay_rates = np.zeros((25, 500))
for i, ii in np.ndindex(25, 500):
    mut_decay_rates[i, ii] = get_single_decay_rate(mut_curves[i, ii, :], mut=muts[i])[0]

In [123]:
# Save the fitted decay rates with their mutation-rate values.
# The plot_data/ directory must already exist.
with open('plot_data/mut_accuracy.pkl', 'wb') as f:
    pickle.dump((mut_decay_rates, muts), f)

### Comparison of ruggedness metrics on NK

In [124]:
# (N, K) settings to compare: N = 12 with every K from 0 to 11.
test_NKs = np.array([[12,0],[12,1],[12,2],[12,3],[12,4],[12,5],[12,6],[12,7],[12,8],[12,9],[12,10],[12,11]])
# Shape of the complete binary landscape, taken from the first row's N
# (every row shares the same N).
shape = (2,) * int(test_NKs[0][0])
NK_roughness_to_slope = []
NK_fourier = []
convergence_rates = []
convergence_rates_extra = []
NK_paths_to_max = []
NK_closest_max = []
NK_local_epistasis = []
rng = jr.PRNGKey(42)
# The same 10 keys are used for every (N, K) setting.
rng_list = jr.split(rng, 10)

for N, K in test_NKs:
    print(N, K)
    # Convert the numpy integers to plain Python ints once per setting.
    N = int(N)
    K = int(K)
    for r in tqdm.tqdm(rng_list):
        complete_landscape = get_nk_l_o_shape(r, N, K, shape)
        NK_roughness_to_slope.append(roughness_to_slope(complete_landscape))
        NK_fourier.append(landscape_r2(complete_landscape))
        NK_paths_to_max.append(get_mean_paths_to_max(complete_landscape, norm=False, extra_slack = 0))
        # Stored as the reciprocal of the distance.
        NK_closest_max.append(1/find_distance_to_closest_max(complete_landscape))
        # Start point is the all-zeros sequence of length N.
        NK_local_epistasis.append(local_epistasis(complete_landscape, [0] * N))
        cr, extr = get_convergence_rate(r, N, K)
        convergence_rates.append(cr)
        convergence_rates_extra.append(extr)

12 0


100%|██████████| 10/10 [00:36<00:00,  3.63s/it]


12 1


100%|██████████| 10/10 [00:36<00:00,  3.64s/it]


12 2


100%|██████████| 10/10 [00:37<00:00,  3.74s/it]


12 3


100%|██████████| 10/10 [00:36<00:00,  3.67s/it]


12 4


100%|██████████| 10/10 [00:36<00:00,  3.64s/it]


12 5


100%|██████████| 10/10 [00:36<00:00,  3.70s/it]


12 6


100%|██████████| 10/10 [00:36<00:00,  3.64s/it]


12 7


100%|██████████| 10/10 [00:36<00:00,  3.66s/it]


12 8


100%|██████████| 10/10 [00:35<00:00,  3.51s/it]


12 9


100%|██████████| 10/10 [00:35<00:00,  3.50s/it]


12 10


100%|██████████| 10/10 [00:34<00:00,  3.46s/it]


12 11


100%|██████████| 10/10 [00:34<00:00,  3.45s/it]


In [125]:
# Reshaping
# Each list holds 12 (N, K) settings x 10 seeds, appended setting by setting,
# so row r of every reshaped array belongs to the r-th setting.
NK_roughness_to_slope_r = np.array(NK_roughness_to_slope).reshape(12,10)
NK_fourier_r = np.array(NK_fourier).reshape(12,10)
convergence_rates_r = np.array(convergence_rates).reshape(12,10)
NK_paths_to_max_r = np.array(NK_paths_to_max).reshape(12,10)
NK_closest_max_r = np.array(NK_closest_max).reshape(12,10)
# Sum of the two sign-epistasis entries of each local_epistasis result.
# NOTE(review): 'simple_sign_episasis' is spelled without the second 't';
# presumably it matches the key produced by local_epistasis — confirm.
NK_local_epistasis_r = np.array([i['simple_sign_episasis'] + i['reciprocal_sign_epistasis'] for i in NK_local_epistasis]).reshape(12,10)


In [126]:
# Average every metric over the 10 seeds, giving one value per (N, K) setting.
mean_NK_roughness_to_slope = [np.array(i).mean() for i in NK_roughness_to_slope_r]
mean_NK_fourier = [np.array(i).mean() for i in NK_fourier_r]
mean_convergence_rates = [np.array(i).mean() for i in  convergence_rates_r]
#mean_convergence_rates_extra = [np.array(i).mean() for i in convergence_rates_extra]
mean_NK_paths_to_max = [np.array(i).mean() for i in NK_paths_to_max_r]
mean_NK_closest_max = [np.array(i).mean() for i in NK_closest_max_r]
mean_NK_local_epistasis = [np.array(i).mean() for i in NK_local_epistasis_r]

In [127]:
# Map each mean reciprocal distance m to 1 - 1/m.
# NOTE: this reads and overwrites mean_NK_closest_max, so running the cell a
# second time applies the transform again to the already transformed values.
mean_NK_closest_max = 1 - 1/np.array(mean_NK_closest_max)

In [128]:
def norm_data(data):
    """Min-max scale ``data`` so its smallest value maps to 0 and its largest to 1.

    Args:
        data: Array-like of numbers (list, tuple or array of any rank). The
            minimum and maximum are taken over all elements.

    Returns:
        A NumPy array with the same shape as ``data``. If every element is
        equal the span is zero and the result is NaN (NumPy division by zero).

    Raises:
        ValueError: If ``data`` is empty.
    """
    # np.asarray lets plain lists work; NumPy reductions (rather than the
    # builtin min/max) also handle multi-dimensional input.
    values = np.asarray(data)
    lowest = values.min()
    span = values.max() - lowest
    return np.array((values - lowest) / span)

# Min-max scale every averaged metric so the metrics can be compared on one axis.
mean_NK_roughness_to_slope = norm_data(np.array(mean_NK_roughness_to_slope))
mean_NK_fourier = norm_data(np.array(mean_NK_fourier))
mean_NK_convergence_rates = norm_data(np.array(mean_convergence_rates))
mean_NK_paths_to_max = norm_data(np.array(mean_NK_paths_to_max))
mean_NK_closest_max = norm_data(np.array(mean_NK_closest_max))
mean_NK_le_normed = norm_data(np.array(mean_NK_local_epistasis))


# NOTE: despite its name, k_over_ns holds (K + 1) / N for the test settings.
k_over_ns = (test_NKs[:,1]+1)/test_NKs[:,0]

In [129]:
# Save the normalised metrics and the (K + 1) / N axis as one tuple.
# The plot_data/ directory must already exist.
with open('plot_data/NK_ruggedness_metric_comparison.pkl', 'wb') as f:
    pickle.dump((mean_NK_roughness_to_slope, mean_NK_fourier, mean_NK_convergence_rates, mean_NK_paths_to_max, mean_NK_closest_max, k_over_ns, mean_NK_le_normed), f) 

## Empirical ruggedness prediction

### Comparison of ruggedness metrics on real landscapes

In [130]:
# Load the four empirical landscape arrays from paths relative to the
# working directory. Only unpickle files you trust.
with open('landscape_arrays/GB1_landscape_array.pkl', 'rb') as f:
    GB1 = pickle.load(f)

# The ParD3 landscape is stored under the name "E3".
with open('landscape_arrays/E3_landscape_array.pkl', 'rb') as f:
    ParD3 = pickle.load(f)

with open('landscape_arrays/TEV_landscape_array.pkl', 'rb') as f:
    TEV = pickle.load(f)

with open('landscape_arrays/TrpB_landscape_array.pkl', 'rb') as f:
    TrpB = pickle.load(f)

In [131]:
# Fixed order used by every per-landscape list below: GB1, TrpB, TEV, ParD3.
empirical_landscapes = [GB1, TrpB, TEV, ParD3]

In [132]:
# Decay rate

with open(os.path.join(slide_data_dir, "decay_curves_gb1_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    gb1_decay = pickle.load(f)
with open(os.path.join(slide_data_dir, "decay_curves_trpb_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    trpb_decay = pickle.load(f)
with open(os.path.join(slide_data_dir, "decay_curves_tev_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    tev_decay = pickle.load(f)
with open(os.path.join(slide_data_dir, "decay_curves_pard3_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    pard3_decay = pickle.load(f)

# Mean of the squared curves over the first three axes, rescaled to start at 1,
# then fitted; half of the first fitted parameter is kept.
# NOTE(review): get_single_decay_rate is called with its default `mut` here,
# while empirical_strategy_selection later fits the same files with mut=0.1 —
# confirm that the difference is intended.
decay_rate_measurements = [(i**2).mean(axis=(0,1,2)) for i in [gb1_decay, trpb_decay, tev_decay, pard3_decay]]
decay_rate_measurements = [i/i[0] for i in decay_rate_measurements]
decay_rate_measurements = [get_single_decay_rate(i)[0]/2 for i in decay_rate_measurements]

In [133]:
# Roughness to slope, one value per landscape in empirical_landscapes order.
roughness_to_slope_measurements  = [roughness_to_slope(i) for i in empirical_landscapes]

In [134]:
# Landscape R2, stored as 1 - R2, one value per landscape.
landscape_r2_measurements = [1-landscape_r2(i) for i in empirical_landscapes]

In [135]:
# Local epistasis
# One start coordinate per landscape, in empirical_landscapes order; the last
# entry (ParD3) has three coordinates, the others four.
starting_points = [[3,17,0,3],[3,8,3,18],[19,17,11,18],[17,12,16]]
landscapes_and_starts = list(zip(empirical_landscapes,starting_points))
local_epistasis_measurements = [local_epistasis(i, np.array(ii)) for i, ii in landscapes_and_starts]

In [136]:
# Paths to global maximum
paths_to_max_measurements = [get_mean_paths_to_max(i, norm=False) for i in empirical_landscapes]
# Index 3 is ParD3; rescale it by the ratio of max_possible_paths for the GB1
# shape to that for the ParD3 shape.
paths_to_max_measurements[3] = paths_to_max_measurements[3]*(max_possible_paths(GB1.shape)/max_possible_paths(ParD3.shape))

In [137]:
# Distance to closest local maximum, one value per landscape.
local_max_measurements = [find_distance_to_closest_max(i) for i in empirical_landscapes]

In [138]:
# Save all six per-landscape metric lists as one tuple.
# The plot_data/ directory must already exist.
with open('plot_data/empirical_ruggedness_metric_comparison.pkl', 'wb') as f:
    pickle.dump((decay_rate_measurements, roughness_to_slope_measurements, landscape_r2_measurements, local_epistasis_measurements, paths_to_max_measurements, local_max_measurements), f) 

### Fourier spectra of empirical landscapes

In [139]:
# Spectrum of each empirical landscape, computed with identical options.
# on_gpu=True presumably requires a GPU-enabled backend — confirm.
full_spectrum_gb1 = get_landscape_spectrum(GB1, remove_constant= False, on_gpu=True, norm = True)
full_spectrum_trpb = get_landscape_spectrum(TrpB, remove_constant= False, on_gpu=True, norm = True)
full_spectrum_tev = get_landscape_spectrum(TEV, remove_constant= False, on_gpu=True, norm = True)
full_spectrum_pard3 = get_landscape_spectrum(ParD3, remove_constant= False, on_gpu=True, norm = True)

# Saved in the order GB1, TrpB, TEV, ParD3. plot_data/ must already exist.
with open('plot_data/fourier_spectra_empirical.pkl', 'wb') as f:
    pickle.dump((full_spectrum_gb1 , full_spectrum_trpb, full_spectrum_tev, full_spectrum_pard3),f)

### Sub-sampling accuracy

In [140]:
# Flatten every decay dataset to (number of curves, 25 steps).
flat_data = [d.reshape(-1, 25) for d in [gb1_decay, trpb_decay, tev_decay, pard3_decay]]

In [141]:
def generate_sq_subsample(rng, flat_data, num_samples=100, num_reps = 100):
    """
    Draw repeated random subsamples of curves and return their normalised
    mean-squared curve.

    For each of the num_reps repetitions, num_samples distinct rows of
    flat_data are picked, squared and averaged; the resulting curve is shifted
    so its last point is 0 and scaled so its first point is 1.

    Returns an array with one normalised curve per repetition.
    """

    def draw_once(key, curves):
        """
        Normalised mean-squared curve of a single random subsample.
        """
        row_ids = jnp.arange(curves.shape[0])
        chosen = jr.choice(key, row_ids, shape=(num_samples,), replace=False)
        mean_sq = (curves[chosen] ** 2).mean(axis=0)
        # Shift so the final point is 0, then scale so the first point is 1.
        shifted = mean_sq - mean_sq[-1]
        return shifted / shifted[0]

    keys = jr.split(rng, num_reps)
    # One key per repetition; the data is shared across repetitions.
    batched_draw = jax.vmap(jax.jit(draw_once), in_axes=(0, None))
    return batched_draw(keys, flat_data)

In [142]:
rng=jr.PRNGKey(0)

# Subsample the first dataset in flat_data (GB1) with 10, 100 and 1000 curves
# per draw, 200 draws each.
# NOTE: the same key `rng` is passed to all three calls.
bad_subsampy = generate_sq_subsample(rng, flat_data[0], num_samples=10, num_reps=200)
subsampy = generate_sq_subsample(rng, flat_data[0], num_samples=100, num_reps=200)
good_subsampy = generate_sq_subsample(rng, flat_data[0], num_samples=1000, num_reps=200)

# Fit every subsampled curve (default `mut`).
bad_curve_params = jnp.array([get_single_decay_rate(s) for s in bad_subsampy])
curve_params = jnp.array([get_single_decay_rate(s) for s in subsampy])
good_curve_params = jnp.array([get_single_decay_rate(s) for s in good_subsampy])

In [143]:
# Takes 30s
num_reps_used = 200
# 8 log-spaced sample sizes between 10 and 1000 (truncated to ints).
num_samples_totest = np.logspace(1,3, num=8, dtype=int)

# For each dataset and sample size: mean and standard deviation, over the
# repetitions, of the first fitted parameter.
# NOTE: `rng` is reused unchanged for every dataset and sample size, and this
# loop overwrites subsampy and curve_params from the previous cell.
results_means = []
results_stds = []
for d in flat_data:
    results_l_m = []
    results_l_s = []
    for num_samples in num_samples_totest:
        subsampy = generate_sq_subsample(rng, d, num_samples=num_samples, num_reps=num_reps_used)
        curve_params = jnp.array([get_single_decay_rate(s) for s in subsampy])
        results_l_m.append(curve_params[:,0].mean())
        results_l_s.append(jnp.std(curve_params[:,0], axis=0))

    results_means.append(results_l_m)
    results_stds.append(results_l_s)

# Halved, matching the /2 applied to the full-data decay rate measurements.
results_means = jnp.array(results_means)/2
results_stds = jnp.array(results_stds)/2

In [144]:
# Labels in the same order as flat_data.
data_names = ['GB1', 'TrpB', 'TEV', 'ParD3']

# The plot_data/ directory must already exist.
with open('plot_data/subsampling_empirical_estimates.pkl', 'wb') as f:
    pickle.dump((num_samples_totest, results_means, results_stds, data_names), f)

## Optimising directed evolution

### Optimal DE strategies from sweep

In [145]:
## Taking the mean over N.
# NOTE: strategy_data and normalized_decay_curves must still hold the large
# NK sweep here; strategy_data is reloaded with other data later on.
reshaped_strategies = strategy_data.reshape(100, -1, 300)
# First 90 settings only: average over the last axis, regroup as (9, 10, 49)
# and average over the first axis, leaving one 49-entry row per group of 10.
# Presumably axis 0 of the regrouped array is N and axis 1 is K — confirm.
N_meaned_strategies = reshaped_strategies[:90].mean(axis=2).reshape(9,10,49).mean(axis=0)
# Decay curves regrouped as (10, 10, 250, 25) and averaged over axes 0 and 2.
N_meaned_decay_data = normalized_decay_curves.reshape(10,10,250,25).mean(axis=(0,2))

In [146]:
# NOTE: this rebinds decay_rates, replacing the per-run array computed in the
# ruggedness-accuracy section with a 1-D array of 10 rates.
decay_rates = []

# One fitted decay rate per averaged curve.
for i in N_meaned_decay_data:
    decay_rates.append(get_single_decay_rate(i,mut=0.5)[0])
    
decay_rates = np.array(decay_rates)

In [147]:
## Getting optimal base chances and splits from strategy sweep.

thresholds, base_chances = base_chance_threshold_fixed_prop([0,0.19], 0.2, 7)
splits = [24,20,16,12,8,4,1]

# Flat lookup tables over the 7 x 7 strategy grid: the base chances repeat
# within each block of 7 entries and the split count changes between blocks.
base_chance_array = np.array([base_chances]*7).flatten()
splitting_array = np.array([[24]*7,[20]*7,[16]*7,[12]*7,[8]*7,[4]*7,[1]*7]).flatten()

optimal_splits = []
optimal_base_chances = []

# For each of the 10 averaged rows, take the best-scoring grid position.
for i in range(10):
    max_val = N_meaned_strategies[i].argmax()
    optimal_splits.append(splitting_array[max_val])
    optimal_base_chances.append(base_chance_array[max_val])

In [148]:
# Save the 10 decay rates with the matching optimal splits and base chances.
# The plot_data/ directory must already exist.
with open('plot_data/optimal_DE_strategies.pkl', 'wb') as f:
    pickle.dump((decay_rates, optimal_splits, optimal_base_chances), f)

### Accuracy of strategy prediction

In [149]:
# Convert the list of (N, K) pairs to a 2-column array.
NKs = np.array(NKs)
# NOTE: despite its name, k_over_ns holds (K + 1) / N.
k_over_ns = (NKs[:,1]+1)/NKs[:,0]

In [150]:
## Per-run strategy predictions from the first 100 runs of every landscape.

predicted_base_chances = []
predicted_splittings = []
actual_k_over_ns = []

for landscape in range(normalized_decay_curves.shape[0]):

    for run in range(100):

        ### Estimate decay rate (k/n).
        run_data = normalized_decay_curves[landscape, run, :]
        run_decay_rate = get_single_decay_rate(run_data, mut=0.5)[0]
        actual_k_over_ns.append(k_over_ns[landscape])

        ### Predict the strategy from the reference decay rate closest to this run's rate.
        nearest = np.argmin(np.abs(decay_rates - run_decay_rate))
        predicted_base_chance = optimal_base_chances[nearest]
        predicted_splitting = optimal_splits[nearest]
        predicted_base_chances.append(predicted_base_chance)
        predicted_splittings.append(predicted_splitting)

In [151]:
# Bin the runs by their ratio rounded to one decimal place.
rounded_actual = np.round(actual_k_over_ns, 1)

unique_x = np.unique(rounded_actual)

bc_means, bc_stds = [], []
sp_means, sp_stds = [], []

for x in unique_x:
    # Mask of the runs that fall into the current bin.
    in_bin = rounded_actual == x

    bc_values = np.array(predicted_base_chances)[in_bin]
    bc_means.append(np.mean(bc_values))
    bc_stds.append(np.std(bc_values))

    sp_values = np.array(predicted_splittings)[in_bin]
    sp_means.append(np.mean(sp_values))
    sp_stds.append(np.std(sp_values))

In [152]:
# Save the per-run ratios together with the per-bin prediction statistics.
# The plot_data/ directory must already exist.
with open('plot_data/strategy_prediction_accuracy.pkl', 'wb') as f:
    pickle.dump((actual_k_over_ns, bc_means, bc_stds, sp_means, sp_stds), f)

### NK directed evo

In [153]:
## Get predicted base chance and splitting values from N = 45, K = [1,25,45]

# Four runs: the first two use the fixed strategy seeded into the lists below
# (base chance 0.0, threshold 0.8, 1 split); the last two use predictions.
NK_samples = [(45,1), (45,25),(45,1), (45,25),]
# Start offsets into the per-run prediction lists (100 runs per landscape).
indexes_of_interest = [1900,1400,1000]
NK_bc_predictions = [0.0,0.0]
NK_th_predictions = [0.8,0.8]
NK_sp_predictions = [1,1]

# NOTE: three predictions are appended, so each list ends with five entries,
# but NK_samples has four; the later loop over range(4) never reads the last one.
for i in indexes_of_interest:

    mean_bc_pred = np.array(predicted_base_chances[i:i+100]).mean()
    mean_sp_pred = np.array(predicted_splittings[i:i+100]).mean()

    # Get the closest to the standard values.
    bc = base_chances[np.argmin(np.abs(base_chances - mean_bc_pred))]
    # The threshold is taken at the same grid position as the chosen base chance.
    th = thresholds[np.argmin(np.abs(base_chances - mean_bc_pred))]
    # `splits` is a plain list; the subtraction relies on NumPy broadcasting
    # with the NumPy scalar mean_sp_pred.
    sp = splits[np.argmin(np.abs(splits - mean_sp_pred))]

    NK_bc_predictions.append(bc)
    NK_th_predictions.append(th)
    NK_sp_predictions.append(sp)

In [154]:
# This version of the DE function is modified to allow splits (by ensuring each split begins from the same starting location).

def directedEvolution(s_rng,
                      rng_rep,
                      selection_strategy,
                      selection_params,
                      empirical=False,
                      N=None,
                      K=None,
                      landscape=None,
                      popsize=100,
                      mut_chance=0.01,
                      num_steps=50,
                      num_reps=10,
                      define_i_pop=None,
                      average=True):
    """Run replicate directed-evolution simulations for one split of a repeat.

    This definition replaces the earlier ``directedEvolution`` in this
    notebook: it takes two keys and has no ``pre_optimisation_steps`` argument.

    Args:
        s_rng: Key for this split; split into ``num_reps`` simulation seeds.
        rng_rep: Key shared by all splits of one repeat. It seeds both the
            random start sequence and the NK landscape, so every split of a
            repeat starts from the same sequence on the same landscape.
        selection_strategy: Selection callable passed to ``build_selection_function``.
        selection_params: Parameter dict passed alongside ``selection_strategy``.
        empirical: If True, fitness comes from ``landscape`` and the mutation
            function is built with 20 states; otherwise an NK landscape is built
            from ``N`` and ``K`` and the mutation function uses 2 states.
        N: Sequence length. Required when ``define_i_pop`` is None or when
            ``empirical`` is False.
        K: NK interaction parameter (only used when ``empirical`` is False).
        landscape: Empirical landscape (only used when ``empirical`` is True).
        popsize: Number of copies of the random start sequence. Ignored when
            ``define_i_pop`` is supplied.
        mut_chance: Mutation chance handed to ``build_mutation_function``.
        num_steps: Number of steps per run.
        num_reps: Number of independent replicate runs (vmapped).
        define_i_pop: Optional explicit initial population.
        average: Unused; kept so existing callers keep working.

    Returns:
        The second element of ``run_directed_evolution``'s return value,
        batched over ``num_reps`` replicates.
    """
    # Get initial population.
    # Identity check: `== None` on an array-valued population is an array
    # comparison, not a reliable "was it supplied?" test.
    if define_i_pop is None:
        i_pop = jnp.array([jr.randint(rng_rep, (N,), 0, 2)] * popsize)
    else:
        i_pop = define_i_pop

    # Function for evaluating fitness (and the matching mutation function).
    if empirical:
        fitness_function = build_empirical_landscape_function(landscape)
        mutation_function = build_mutation_function(mut_chance, 20)
    else:
        # NOTE(review): rng_rep is also used for the start sequence above
        # (JAX key reuse). Left unchanged to keep earlier results reproducible.
        fitness_function = build_NK_landscape_function(rng_rep, N, K)
        mutation_function = build_mutation_function(mut_chance, 2)

    # Define selection function.
    selection_function = build_selection_function(
        selection_strategy, selection_params)

    # Bringing it all together: one jitted, vmapped run per replicate seed.
    vmapped_run = jax.jit(jax.vmap(lambda r: run_directed_evolution(
        r, i_pop, selection_function, mutation_function,
        fitness_function=fitness_function, num_steps=num_steps)[1]))

    # The array of seeds we will take as input.
    rng_seeds = jr.split(s_rng, num_reps)
    results = vmapped_run(rng_seeds)

    return results

In [155]:
three_NK_results = []

# Settings shared by every run: total population budget, mutations per
# sequence per step, and the 100 repeat keys (identical for every run).
p = 1200
m = 0.1
rep_rngs = jr.split(jr.PRNGKey(42), 100)

for i in range(len(NK_samples)):

    N, K = NK_samples[i]
    bc = NK_bc_predictions[i]
    s = NK_sp_predictions[i]
    th = NK_th_predictions[i]

    print('N: ', N, ', K: ', K, ', bc: ', bc, ', sp: ', s)

    def single_rep(rng_rep):
        """Run all `s` splits of one repeat; every split shares rng_rep."""

        split_rngs = jr.split(rng_rep, s)

        def single_s(s_rng):
            """Run one split with its share (p / s) of the population budget."""
            params = {'threshold': th, 'base_chance' : bc}
            run = directedEvolution(s_rng,
                                    rng_rep,
                                    N = int(N),
                                    K=int(K),
                                    selection_strategy=slct.base_chance_threshold_select,
                                    selection_params = params,
                                    popsize=int(p/s),
                                    mut_chance=m/int(N),
                                    num_steps=50,
                                    num_reps=1,
                                    average=True)

            # Average over the leading replicate axis (size 1 because num_reps=1).
            return run['fitness'][:,:,:].mean(axis=0)

        split_results = jax.vmap(single_s)(split_rngs)

        return jnp.array(split_results)

    repeat_results = jax.vmap(single_rep)(rep_rngs)

    three_NK_results.append(repeat_results)

N:  45 , K:  1 , bc:  0.0 , sp:  1
N:  45 , K:  25 , bc:  0.0 , sp:  1
N:  45 , K:  1 , bc:  0.0 , sp:  4
N:  45 , K:  25 , bc:  0.0 , sp:  24


In [156]:
## Extracting winning splits in post
# For every repeat keep only the split whose final step contains the highest
# fitness, then average those trajectories over repeats and over the last axis.
new_output = []
for i in three_NK_results:
    winning_splits_only = []
    for ii in range(i.shape[0]): # for each rep
        rep = i[ii]
        # Best value at the final step, per split.
        final_max = rep[:,-1,:].max(axis=1)
        winning_split = np.argmax(final_max)
        winning_splits_only.append(rep[winning_split])
    new_output.append(np.array(winning_splits_only).mean(axis=(0,2)))


In [157]:
# Save the averaged winning-split trajectories, one per run.
# The plot_data/ directory must already exist.
with open('plot_data/NK_DE.pkl', 'wb') as f:
    pickle.dump(new_output, f)

### Strategy spaces

In [158]:
# Save the raw strategy sweeps for settings 19 and 14 (the same two settings
# that the offsets 1900 and 1400 select in the prediction lists).
# NOTE: must run before strategy_data is reloaded with the N4A20 sweep below.
reshaped_strategies = strategy_data.reshape(100, -1, 300)
with open('plot_data/NK_strategy_spaces.pkl', 'wb') as f:
    pickle.dump((reshaped_strategies[19,:,:],reshaped_strategies[14,:,:]), f)

## Directed evolution on real landscapes

In [159]:
# NOTE: these loads rebind strategy_data and decay_data, replacing the large
# NK sweep loaded at the top of the notebook. Only unpickle files you trust.
with open(os.path.join(slide_data_dir, "N4A20_strategy_sweep.pkl"), "rb") as f:
    strategy_data = pickle.load(f)

with open(os.path.join(slide_data_dir, "N4A20_decay_curves.pkl"), "rb") as f:
    decay_data = pickle.load(f)

In [160]:
# Reference decay rates: mean of the squared curves over the first three axes,
# rescaled to start at 1, fitted, and the first fitted parameter halved.
# The first entry of decay_data is skipped.
decay_rates = []
decay_means = []
for i in decay_data[1:]:
    decay_mean = (i**2).mean(axis=(0,1,2))
    decay_mean = decay_mean / decay_mean[0]
    decay_means.append(decay_mean)
    decay_rate = get_single_decay_rate(decay_mean,mut=0.1)
    decay_rates.append(decay_rate[0]/2)

# Best grid position per reference landscape, after averaging the sweep over
# its first axis.
optimal_pos = []
for i in np.array(strategy_data).mean(axis=0):

    # Find the index of the maximum value in the transposed grid. The flat
    # index is unravelled with the transposed shape so that it stays correct
    # for non-square grids (for square grids nothing changes).
    max_pos = np.unravel_index(np.argmax(i.T), i.T.shape)
    optimal_pos.append(max_pos)


In [161]:
## Saving data
# Reference decay rates, their curves, the optimal grid positions and the
# sweep averaged over its first axis. plot_data/ must already exist.
with open('plot_data/empirical_lookup.pkl', 'wb') as f:
    pickle.dump((decay_rates, decay_means, optimal_pos, np.array(strategy_data).mean(axis=0)), f)

In [162]:
# 7 x 7 strategy grid used for the four-site landscapes.
thresholds, base_chances = base_chance_threshold_fixed_prop([0,0.19], 0.2, 7)
splits = [24,20,16,12,8,4,1]

# optimal_pos entries index the transposed grid: element 1 picks the base
# chance and element 0 picks the split count.
optimal_base_chances = [base_chances[i[1]] for i in optimal_pos]
optimal_splits = [splits[i[0]] for i in optimal_pos]

In [163]:
# Loading all files
# The four decay-curve files are the same ones loaded earlier into
# gb1_decay, trpb_decay, tev_decay and pard3_decay. Only unpickle files you trust.

with open(os.path.join(slide_data_dir, "decay_curves_gb1_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    GB1_decay_multi = pickle.load(f)

with open(os.path.join(slide_data_dir, "decay_curves_trpb_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    TrpB_decay_multi = pickle.load(f)

with open(os.path.join(slide_data_dir, "decay_curves_tev_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    TEV_decay_multi = pickle.load(f)

with open(os.path.join(slide_data_dir, "decay_curves_pard3_m0.1_multistart_10000_uniform.pkl"), "rb") as f:
    ParD3_decay_multi = pickle.load(f)


# Strategy sweeps, one per landscape.
with open(os.path.join(slide_data_dir, "strategy_sweep_GB1_multistart_100_uniform_m0.025.pkl"), "rb") as f:
    GB1_sweep_multi = pickle.load(f)

with open(os.path.join(slide_data_dir, "strategy_sweep_TrpB_multistart_100_uniform_m0.025.pkl"), "rb") as f:
    TrpB_sweep_multi = pickle.load(f)

with open(os.path.join(slide_data_dir, "strategy_sweep_TEV_multistart_100_uniform_m0.025.pkl"), "rb") as f:
    TEV_sweep_multi = pickle.load(f)

# The ParD3 sweep is stored under the name "E3".
with open(os.path.join(slide_data_dir, "strategy_sweep_E3_multistart_100_uniform_m0.025.pkl"), "rb") as f:
    ParD3_sweep_multi = pickle.load(f)

In [164]:
def empirical_strategy_selection(decay, sweep, decay_rates, optimal_base_chances, optimal_splits, N=4):
    """Choose a strategy for a landscape from its measured decay curves.

    The squared decay curves are averaged over their first three axes, rescaled
    to start at 1 and fitted with mut=0.1. Half of the first fitted parameter is
    matched to the closest entry of ``decay_rates``, and the base chance and
    split count stored at that position are returned as the strategy.

    Args:
        decay: Decay-curve array for the landscape.
        sweep: Strategy sweep for the landscape; returned unchanged.
        decay_rates: Reference decay rates to match against.
        optimal_base_chances: Base chance for each reference decay rate.
        optimal_splits: Split count for each reference decay rate.
        N: 3 selects the 5 x 5 strategy grid; any other value the 7 x 7 grid.

    Returns:
        Tuple ``(x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix,
        strategy)`` where ``scipy_freq_matrix`` is an integer grid holding a 1
        at the chosen (split, base chance) position and ``strategy`` is
        ``(base_chance, split)``.
    """
    # Mean squared curve, rescaled so it starts at 1.
    squared_mean = (decay**2).mean(axis=(0,1,2))
    decay_mean = squared_mean / squared_mean[0]
    decay_rate = get_single_decay_rate(decay_mean, mut=0.1)
    x_vals = np.linspace(0, 24, 25)

    # Look up the reference landscape with the closest (halved) decay rate.
    nearest = np.argmin(np.abs(decay_rates - decay_rate[0]/2))
    strategy = (optimal_base_chances[nearest], optimal_splits[nearest])

    # Strategy grid matching the landscape size.
    grid_size, splits = (5, [20,15,10,5,1]) if N == 3 else (7, [24,20,16,12,8,4,1])
    scipy_freq_matrix = np.zeros((grid_size, grid_size), dtype=int)
    _, base_chances = base_chance_threshold_fixed_prop([0,0.19], 0.2, grid_size)

    # Mark the chosen strategy: rows index splits, columns index base chances.
    col = np.where(np.array(base_chances) == strategy[0])[0]
    row = np.where(np.array(splits) == strategy[1])[0]
    scipy_freq_matrix[row, col] = 1

    return (x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, strategy)

def uniform_start_locs(ld, num=10000):
    """Return ``num`` start coordinates spread evenly over the landscape.

    Flat indices are spaced evenly (and rounded) between the first and the last
    cell of ``ld`` and then converted to multi-dimensional coordinates.

    Args:
        ld: Landscape array; only its shape is used.
        num: Number of start locations to generate.

    Returns:
        Integer array of shape ``(num, ld.ndim)``, one coordinate per row
        (``(0, ld.ndim)`` when ``num`` is 0).
    """
    landscape_shape = tuple(ld.shape)
    n_cells = int(np.prod(landscape_shape))
    flat_indexes = np.round(np.linspace(0, n_cells - 1, num)).astype(int)
    # A single vectorised unravel instead of one Python-level call per index.
    coords = np.unravel_index(flat_indexes, landscape_shape)
    return np.stack(coords, axis=1)

def test_strategy_empirical(ld, bcs, sps, ths, starts):
    """Simulate several selection strategies on an empirical landscape.

    For every start location and every strategy, 100 repeats are run. Each
    repeat consists of ``s`` splits that share the total population budget, and
    only the split whose final step reaches the highest fitness is kept.

    Args:
        ld: Empirical landscape with 3 or 4 axes (population budget 60 or 1200).
        bcs: Base chance of each strategy.
        sps: Number of splits of each strategy.
        ths: Threshold of each strategy.
        starts: Iterable of start coordinates; the whole initial population of
            every split is placed on the start.

    Returns:
        A list with one entry per start; each entry is a list with one
        fitness trajectory per strategy, averaged over the repeats and over the
        last axis of the winning splits.

    Raises:
        ValueError: If ``ld`` does not have 3 or 4 axes.
    """
    ld = jnp.array(ld)

    # Total population budget by landscape rank. Previously an unsupported rank
    # only surfaced later as a NameError on the undefined `p`.
    budget_by_rank = {4: 1200, 3: 60}
    if ld.ndim not in budget_by_rank:
        raise ValueError(
            f"landscape must have 3 or 4 axes, got {ld.ndim}")
    p = budget_by_rank[ld.ndim]
    m = 0.01

    # The same 100 repeat keys are used for every start and every strategy.
    rep_rngs = jr.split(jr.PRNGKey(42), 100)

    start_results = []

    for start in tqdm.tqdm(starts):

        results = []

        for bc, s, th in zip(bcs, sps, ths):

            def single_rep(rng_rep):
                """Run all `s` splits of one repeat."""

                split_rngs = jr.split(rng_rep, s)

                def single_s(s_rng):
                    """Run one split with its share (p / s) of the budget."""
                    params = {'threshold': th, 'base_chance' : bc}

                    run = directedEvolution(s_rng,
                                            rng_rep,
                                            selection_strategy=slct.base_chance_threshold_select,
                                            selection_params = params,
                                            popsize=int(p/s),
                                            mut_chance=m,
                                            num_steps=80,
                                            num_reps=1,
                                            define_i_pop=jnp.array([start]*int(p/s)),
                                            empirical=True,
                                            landscape=ld,
                                            average=True)

                    # Average over the leading replicate axis (size 1).
                    return run['fitness'][:,:,:].mean(axis=0)

                split_results = jax.vmap(single_s)(split_rngs)

                return jnp.array(split_results)

            repeat_results = jax.vmap(single_rep)(rep_rngs)

            results.append(repeat_results)

        ## Extracting winning splits in post
        run = []
        for strategy_result in results:
            winning_splits_only = []
            for rep in strategy_result: # for each rep
                # Best value at the final step, per split.
                final_max = rep[:,-1,:].max(axis=1)
                winning_split = np.argmax(final_max)
                winning_splits_only.append(rep[winning_split])
            run.append(np.array(winning_splits_only).mean(axis=(0,2)))

        start_results.append(run)

    return start_results

In [165]:
## GB1

# Choose the strategy predicted from GB1's decay curves using the
# module-level lookup tables (decay_rates, optimal_base_chances, optimal_splits).
x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, strategy = empirical_strategy_selection(GB1_decay_multi, GB1_sweep_multi, decay_rates, optimal_base_chances, optimal_splits)

# First strategy: fixed reference (base chance 0.0, 1 split, threshold 0.8).
# Second strategy: the predicted one; its threshold is selected with a boolean
# mask over the module-level base_chances, so it is an array slice, not a scalar.
# Simulated from 10 evenly spaced start locations.
run = test_strategy_empirical(GB1, [0.0,strategy[0]], 
                        [1,strategy[1]], 
                        [0.8,thresholds[np.array(base_chances) == strategy[0]]],
                        starts=uniform_start_locs(ld=GB1, num=10))

## Saving data
with open('plot_data/GB1_strategy_selection.pkl', 'wb') as f:
    pickle.dump((x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, run), f)

100%|██████████| 10/10 [00:28<00:00,  2.86s/it]


In [166]:
## TrpB

# Choose the strategy predicted from TrpB's decay curves using the
# module-level lookup tables (decay_rates, optimal_base_chances, optimal_splits).
x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, strategy = empirical_strategy_selection(TrpB_decay_multi, TrpB_sweep_multi, decay_rates, optimal_base_chances, optimal_splits)

# First strategy: fixed reference (base chance 0.0, 1 split, threshold 0.8).
# Second strategy: the predicted one; its threshold is selected with a boolean
# mask over the module-level base_chances, so it is an array slice, not a scalar.
# Simulated from 10 evenly spaced start locations.
run = test_strategy_empirical(TrpB, [0.0,strategy[0]], 
                        [1,strategy[1]], 
                        [0.8,thresholds[np.array(base_chances) == strategy[0]]],
                        starts=uniform_start_locs(ld=TrpB, num=10))

## Saving data
with open('plot_data/TrpB_strategy_selection.pkl', 'wb') as f:
    pickle.dump((x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, run), f)

100%|██████████| 10/10 [00:27<00:00,  2.73s/it]


In [167]:
## TEV

# Choose the strategy predicted from TEV's decay curves using the
# module-level lookup tables (decay_rates, optimal_base_chances, optimal_splits).
x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, strategy = empirical_strategy_selection(TEV_decay_multi, TEV_sweep_multi, decay_rates, optimal_base_chances, optimal_splits)

# First strategy: fixed reference (base chance 0.0, 1 split, threshold 0.8).
# Second strategy: the predicted one; its threshold is selected with a boolean
# mask over the module-level base_chances, so it is an array slice, not a scalar.
# Simulated from 10 evenly spaced start locations.
run = test_strategy_empirical(TEV, [0.0,strategy[0]], 
                        [1,strategy[1]], 
                        [0.8,thresholds[np.array(base_chances) == strategy[0]]],
                        starts=uniform_start_locs(ld=TEV, num=10))

## Saving data
with open('plot_data/TEV_strategy_selection.pkl', 'wb') as f:
    pickle.dump((x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, run), f)

100%|██████████| 10/10 [00:27<00:00,  2.79s/it]


In [168]:
# NOTE: these loads rebind strategy_data and decay_data again, replacing the
# N4A20 data with the three-site reference sweep. Only unpickle files you trust.
with open(os.path.join(slide_data_dir, "N3A20_strategy_sweep.pkl"), "rb") as f:
    strategy_data = pickle.load(f)

with open(os.path.join(slide_data_dir, "N3A20_decay_curves.pkl"), "rb") as f:
    decay_data = pickle.load(f)

In [169]:
# Reference decay rates for the three-site sweep: mean of the squared curves
# over the first three axes, rescaled to start at 1, fitted, and the first
# fitted parameter halved. The first entry of decay_data is skipped.
decay_rates = []
decay_means = []
for i in decay_data[1:]:
    decay_mean = (i**2).mean(axis=(0,1,2))
    decay_mean = decay_mean / decay_mean[0]
    decay_means.append(decay_mean)
    decay_rate = get_single_decay_rate(decay_mean,mut=0.1)
    decay_rates.append(decay_rate[0]/2)

# Best grid position per reference landscape, after averaging the sweep over
# its first axis.
optimal_pos = []
for i in np.array(strategy_data).mean(axis=0):

    # Find the index of the maximum value in the transposed grid. The flat
    # index is unravelled with the transposed shape so that it stays correct
    # for non-square grids (for square grids nothing changes).
    max_pos = np.unravel_index(np.argmax(i.T), i.T.shape)
    optimal_pos.append(max_pos)

In [170]:
# 5-step strategy grid for the three-site landscape. The same call is repeated
# at the top of the next cell.
thresholds, base_chances = base_chance_threshold_fixed_prop([0,0.19], 0.2, 5)

In [171]:
## ParD3

# 5 x 5 strategy grid used for the three-site landscape.
thresholds, base_chances = base_chance_threshold_fixed_prop([0,0.19], 0.2, 5)
splits = [20,15,10,5,1]

# optimal_pos entries index the transposed grid: element 1 picks the base
# chance and element 0 picks the split count.
optimal_base_chances = [base_chances[i[1]] for i in optimal_pos]
optimal_splits = [splits[i[0]] for i in optimal_pos]

# N=3 makes the selection helper use the 5 x 5 grid as well.
x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, strategy = empirical_strategy_selection(ParD3_decay_multi, ParD3_sweep_multi, decay_rates, optimal_base_chances, optimal_splits,N=3)

# First strategy: fixed reference (base chance 0.0, 1 split, threshold 0.8).
# Second strategy: the predicted one; its threshold is selected with a boolean
# mask over base_chances, so it is an array slice, not a scalar.
# Simulated from 10 evenly spaced start locations.
run = test_strategy_empirical(ParD3, [0.0,strategy[0]], 
                        [1,strategy[1]], 
                        [0.8,thresholds[np.array(base_chances) == strategy[0]]],
                        starts=uniform_start_locs(ld=ParD3, num=10))

## Saving data
with open('plot_data/ParD3_strategy_selection.pkl', 'wb') as f:
    pickle.dump((x_vals, decay_mean, decay_rate, sweep, scipy_freq_matrix, run), f)

100%|██████████| 10/10 [00:26<00:00,  2.66s/it]
