In [2]:
import math
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.optimize
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from scipy import stats

In [3]:
# simulation functions
def update(state, fitness, next_gen, burst_size, mut_prob):
    """Advance the within-host population by one step.

    Three stochastic stages are applied in order:

    1. Replication: ``ceil(next_gen / burst_size)`` parents are drawn with
       replacement, with probability proportional to ``state * fitness``.
    2. Expansion: ``int(next_gen)`` survivors are drawn with replacement from
       the parent genotypes, weighted by how often each was drawn as a parent.
    3. Mutation: ``Binomial(next_gen, mut_prob)`` distinct survivors are
       chosen and each moves from genotype class ``i`` to class ``i + 1``.
       Survivors already in the final class stay there.

    Parameters
    ----------
    state : array_like of float
        Number of virions in each genotype class at the current step.
    fitness : array_like of float
        Relative replicative fitness of each genotype class (same length as
        ``state``).
    next_gen : float
        Population size for the next step; truncated with ``int``.
    burst_size : float
        Divisor that sets how many distinct parents are drawn.
    mut_prob : float
        Per-virion probability of a forward mutation this step.

    Returns
    -------
    numpy.ndarray of int
        Counts per genotype class for the next step; sums to ``int(next_gen)``.
    """
    num_classes = len(fitness)
    num_replicating = math.ceil(next_gen / burst_size)

    # 1. replicating parents, weighted by abundance x fitness
    rep_probs = np.multiply(state, fitness)
    rep_probs = rep_probs / np.sum(rep_probs)
    rep_viruses = np.random.choice(num_classes, size=num_replicating, replace=True, p=rep_probs)

    # 2. expand the parents to the next population size
    to_select, select_counts = np.unique(rep_viruses, return_counts=True)
    survivors = np.random.choice(to_select, size=int(next_gen), replace=True,
                                 p=select_counts / np.sum(select_counts))

    # 3. forward mutation of distinct survivors (drawn without replacement,
    #    so no class can lose more virions than it has)
    num_mutating = np.random.binomial(next_gen, mut_prob)
    mutators = np.random.choice(survivors, size=num_mutating, replace=False)

    survivor_counts = np.bincount(survivors, minlength=num_classes)
    mutator_counts = np.bincount(mutators, minlength=num_classes)
    new_state = survivor_counts - mutator_counts
    new_state[1:] += mutator_counts[:-1]
    # The final class has no class to move into. Previously these virions were
    # subtracted and never re-added, silently shrinking the population; keep
    # them in place so the total is conserved.
    new_state[-1] += mutator_counts[-1]
    return new_state

def simulate(viral_load_curve, fitness_cost, fitness_benefit, num_mut, mut_prob, burst_size=1e3):
    """Simulate one within-host infection along a prescribed viral-load curve.

    Genotype class ``k`` (``0 <= k <= num_mut``) is reached by ``k`` forward
    mutations from class 0. Relative fitnesses are:

    * class 0: 1
    * classes 1 .. num_mut-1 (intermediates): 1 - ``fitness_cost``
    * class ``num_mut``: 1 + ``fitness_benefit``

    Parameters
    ----------
    viral_load_curve : sequence of float
        Population size at each step. Entry 0 seeds class 0; each later entry
        is passed to ``update`` as ``next_gen``.
    fitness_cost : float or None
        Fitness reduction of intermediate genotypes. Not used (may be None)
        when ``num_mut <= 1``.
    fitness_benefit : float
        Fitness change of the final genotype.
    num_mut : int
        Number of mutations in the final genotype.
    mut_prob : float
        Per-virion forward-mutation probability per step.
    burst_size : float
        Passed through to ``update``.

    Returns
    -------
    numpy.ndarray, shape (len(viral_load_curve), num_mut + 1)
        Counts per genotype class at every step.
    """
    fitnesses = np.ones(num_mut + 1)
    if num_mut > 1:
        # Every intermediate class 1 .. num_mut-1 sits in the fitness valley.
        # The previous slice [1:num_mut-1] stopped one class short, so with
        # num_mut=2 no valley was applied at all.
        fitnesses[1:num_mut] -= fitness_cost
    fitnesses[-1] += fitness_benefit

    curr_state = np.zeros(num_mut + 1)
    curr_state[0] = viral_load_curve[0]
    all_data = np.zeros((len(viral_load_curve), num_mut + 1))
    all_data[0, :] = curr_state
    for i in range(1, len(viral_load_curve)):
        curr_state = update(curr_state, fitnesses, viral_load_curve[i], burst_size, mut_prob)
        all_data[i, :] = curr_state
    return all_data

In [4]:
# functions for processing and plotting simulation data
def compute_fracs(data, log_data=True, CI=None, median=False):
    """Summarise per-genotype frequency trajectories across replicate simulations.

    Parameters
    ----------
    data : sequence of array_like
        Replicate outputs of ``simulate``, each shaped (time, genotype class).
    log_data : bool
        If True, summarise log10 frequencies; zero frequencies (log10 = -inf)
        are floored at -7.
    CI : float or None
        If None, the band is centre +/- SEM across replicates. Otherwise it is
        the central ``CI`` quantile interval across replicates (e.g. 0.95).
    median : bool
        Use the median instead of the NaN-ignoring mean as the centre.

    Returns
    -------
    list of three numpy.ndarray
        ``[centre, lower, upper]``, each shaped (time, genotype class).
    """
    all_data = np.array(data)

    # Normalise by the first replicate's total load at each time point; this
    # assumes every replicate has the same per-step total.
    kinetics = np.sum(all_data[0], axis=1)
    kinetics = np.reshape(kinetics, (len(kinetics), 1))
    all_data = all_data / kinetics
    if log_data:
        all_data = np.log10(all_data)
        all_data[all_data == -np.inf] = -7

    if median:
        means = np.quantile(all_data, 0.5, axis=0)
    else:
        means = np.nanmean(all_data, axis=0)
    if CI is None:
        # SEM across replicates (axis 0) at every time point and class. The
        # original took the SEM of `means` along time, which gave a single
        # time-invariant band width per class.
        sem = stats.sem(all_data, axis=0, nan_policy="omit")
        lowers = means - sem
        uppers = means + sem
    else:
        lowers = np.quantile(all_data, (1 - CI) / 2, axis=0)
        uppers = np.quantile(all_data, 1 - (1 - CI) / 2, axis=0)
    return [means, lowers, uppers]

def transmit_probs(data, end_time=None, num_trans=100):
    """Per-replicate chance that at least one transmitted virion carries each genotype.

    Each replicate's cumulative load per genotype over the window is divided
    by the first replicate's total cumulative load, giving a per-virion
    probability ``f``. With ``num_trans`` independent draws, the chance of at
    least one hit is ``1 - (1 - f) ** num_trans``.

    Parameters
    ----------
    data : sequence of array_like
        Replicate outputs of ``simulate``, each shaped (time, genotype class).
    end_time : float or None
        If given, only the first ``int(end_time * 2)`` rows are used (rows
        are treated as half-day steps).
    num_trans : int
        Number of virions drawn per transmission.

    Returns
    -------
    numpy.ndarray, shape (replicates, genotype class)
    """
    sims = np.array(data)
    if end_time is not None:
        sims = sims[:, :int(end_time * 2), :]

    cumulative = sims.sum(axis=1)
    window_total = sims[0].sum()
    per_virion = cumulative / window_total
    return 1 - np.power(1 - per_virion, num_trans)

def max_freq(data):
    """Replicate-mean of each genotype's peak frequency over time.

    Frequencies are taken relative to the first replicate's total load at each
    time point.

    Parameters
    ----------
    data : sequence of array_like
        Replicate outputs of ``simulate``, each shaped (time, genotype class).

    Returns
    -------
    numpy.ndarray
        One value per genotype class.
    """
    sims = np.array(data)
    totals = np.sum(data[0], axis=1)[:, np.newaxis]
    freqs = sims / totals
    peak_per_replicate = np.quantile(freqs, 1, axis=1)
    return np.mean(peak_per_replicate, axis=0)

In [None]:
# Distribution of long-infection durations, using data from
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7734137/
# The reported IQR is turned into a normal approximation: SD ~= IQR / 1.35,
# with the centre taken as the IQR midpoint.

# Extended-infection lengths (weeks) simulated for Figure 3. Defined here
# because this cell runs before the Figure 3 cells; previously it relied on
# `extra_weeks` from a later cell and failed on Restart & Run All.
extra_weeks = [0, 1, 2, 3, 4, 5, 10, 15]

long_lower_IQR = 86
long_upper_IQR = 101.5
est_SD = (long_upper_IQR - long_lower_IQR) / 1.35
est_mean = (long_upper_IQR + long_lower_IQR) / 2

from scipy.stats import norm
# NOTE(review): lengths are extra_weeks*7 + 23 - 5 days; the "- 5" offset
# presumably aligns simulated time with the paper's time origin — TODO confirm.
total_lens = np.array(extra_weeks) * 7 + 23 - 5
cdf_vals = norm.cdf(total_lens, loc=est_mean, scale=est_SD)
# Probability mass between consecutive simulated lengths; used later as
# weights for the extra_weeks[1:] results.
long_dist = cdf_vals[1:] - cdf_vals[:-1]

## Figure 1

In [None]:
# simulations for Figure 1C-E: a single mutation (num_mut=1) with a range of fitness effects
viral_load_kinetics = [1000, 1e3, 1e5, 1e5, 1e6, 1e6, 1e6, 1e7, 1e7, 1e7, 1e8, 1e8, 1e9, 1e9, 1e9, 1e9, 1e8, 1e8, 1e8, 1e8, 1e7, 1e7, 1e7, 1e7, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e4, 1e4, 1e4, 1e3]
fitness_benefits = [-0.01, -0.05, 0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
viral_load_adj = np.array(viral_load_kinetics) / 1000

num_iter = 1000
os.makedirs("./results", exist_ok=True)

for benefit in fitness_benefits:
    print(benefit)
    all_res = [simulate(viral_load_adj, None, benefit, 1, 1e-5, burst_size=1) for _ in range(num_iter)]
    with open("./results/standard_" + str(benefit) + ".p", "wb") as fh:
        pickle.dump(all_res, fh)

In [None]:
# data loading for Figure 1C-E
fitness_benefits = [-0.05, -0.01, 0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]

means = []
lowers = []
uppers = []
early_probs = []
mid_probs = []
midlate_probs = []
late_probs = []
for benefit in fitness_benefits:
    # pickle.load can execute code; only load result files this notebook wrote
    with open("./results/standard_" + str(benefit) + ".p", "rb") as fh:
        data = pickle.load(fh)
    mean, lower, upper = compute_fracs(data, log_data=False, CI=None, median=False)
    means.append(mean)
    lowers.append(lower)
    uppers.append(upper)
    # transmission windows ending at day 3, 5, 7 and 23
    early_probs.append(transmit_probs(data, 3, num_trans=10))
    mid_probs.append(transmit_probs(data, 5, num_trans=10))
    midlate_probs.append(transmit_probs(data, 7, num_trans=10))
    late_probs.append(transmit_probs(data, 23, num_trans=10))

# column 1 is the mutant genotype (num_mut=1 runs)
to_plot = pd.DataFrame({
    "benefit": fitness_benefits,
    "day3": [np.mean(x[:, 1]) for x in early_probs],
    "day5": [np.mean(x[:, 1]) for x in mid_probs],
    "day7": [np.mean(x[:, 1]) for x in midlate_probs],
    'any': [np.mean(x[:, 1]) for x in late_probs],
}).sort_values(by="benefit")
prob_SEMs = pd.DataFrame({
    "benefit": fitness_benefits,
    "day3": [stats.sem(x[:, 1]) for x in early_probs],
    "day5": [stats.sem(x[:, 1]) for x in mid_probs],
    "day7": [stats.sem(x[:, 1]) for x in midlate_probs],
    'any': [stats.sem(x[:, 1]) for x in late_probs],
}).sort_values(by="benefit")

In [None]:
# plotting Fig. 1C: mean intrahost frequency of the single mutant over time,
# one line (+/- SEM band) per fitness effect, coloured purple -> cyan
N_color = LinearSegmentedColormap.from_list(colors=["purple", "cyan"], name="N")

# rows are plotted at half-day spacing
xplot = np.arange(np.shape(data[0])[0]) / 2
mut_index = 1

plt.figure()

for benefit, centre, band_lo, band_hi in zip(fitness_benefits, means, lowers, uppers):
    # benefits span [-0.05, 0.5]; rescale onto the colormap's [0, 1] range
    shade = N_color((benefit + 0.05) / .55)
    plt.plot(xplot, centre[:, mut_index], color=shade)
    plt.fill_between(xplot, band_lo[:, mut_index], band_hi[:, mut_index], color=shade, alpha=0.3)

plt.ylabel("mean intrahost\nfrequency of variant", fontsize=14)
plt.xlabel("days since infection", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
sns.despine()
plt.tight_layout()
#plt.savefig(".pdf", transparent=True)
plt.show()

In [None]:
# plotting Fig. 1D: probability of passing on the variant vs its fitness
# effect, one line (+/- SEM band) per transmission window
window_styles = [("day3", "lightcoral"), ("day5", "indianred"), ("day7", "firebrick"), ("any", "maroon")]

plt.figure()
for column, colour in window_styles:
    centre = to_plot[column]
    spread = prob_SEMs[column]
    plt.plot(to_plot["benefit"], centre, color=colour)
    plt.fill_between(to_plot["benefit"], centre - spread, centre + spread, color=colour, alpha=0.1)
plt.ylabel("probability of passing on variant", fontsize=14)
plt.xlabel("fitness effect of mutation", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
sns.despine()
plt.tight_layout()
#plt.savefig(".pdf", transparent=True)
plt.show()

In [None]:
# plotting Fig. 1E: heatmap of P(transmit mutant by day 7) * p_surv * N
# over (R0 of the new mutant, fitness effect of the mutation)
N = 5e4
lams = [1.1, 1.2, 1.3, 1.4, 1.5]

est_prob = []
for lam in lams:
    # x = 1 - exp(-lam*x) is the survival-probability equation of a branching
    # process with Poisson(lam) offspring. x = 0 is always a root; the start
    # value 0.1 is relied on to reach the non-zero root for lam > 1.
    p_surv = scipy.optimize.broyden1(lambda x: 1 - np.exp(-lam * x) - x, 0.1, f_tol=1e-14)
    est_prob.append(to_plot["day7"] * p_surv * N)

plt.figure()
sns.heatmap(est_prob, cmap="Reds").invert_yaxis()
plt.ylabel("$R_0$ of new mutant", fontsize=14)
plt.xlabel("fitness effect of mutation", fontsize=14)
plt.xticks(ticks=np.arange(10) + 0.5, labels=fitness_benefits, fontsize=12)
plt.yticks(ticks=np.arange(5) + 0.5, labels=lams, fontsize=12)
plt.tight_layout()
#plt.savefig(".eps", transparent=True)
plt.show()

## Figure 2

In [None]:
# simulations for Fig. 2A: single beneficial mutation, with the 1e9 peak
# extended by `extra_days` (2 curve entries per extra day)
left_kinetics = [1000, 1e3, 1e5, 1e5, 1e6, 1e6, 1e6, 1e7, 1e7, 1e7, 1e8, 1e8]
right_kinetics = [1e9, 1e9, 1e9, 1e9, 1e8, 1e8, 1e8, 1e8, 1e7, 1e7, 1e7, 1e7, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e4, 1e4, 1e4, 1e3]
extra_days = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
benefit = 0.2

num_iter = 1000
os.makedirs("./results", exist_ok=True)

# loop variable named `n_extra` (not `time`) so the imported `time` module is not shadowed
for n_extra in extra_days:
    print(n_extra)
    new_viral_load = left_kinetics + [1e9] * (n_extra * 2) + right_kinetics
    new_viral_load = np.array(new_viral_load) / 1000
    all_res = [simulate(new_viral_load, None, benefit, 1, 1e-5, burst_size=1) for _ in range(num_iter)]
    with open("./results/single_long_" + str(n_extra) + ".p", "wb") as fh:
        pickle.dump(all_res, fh)

In [None]:
# plotting Fig. 2A: probability of passing on the variant vs infection length
extra_days = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

means = []
lowers = []
uppers = []
trans_probs = []
for n_extra in extra_days:
    # pickle.load can execute code; only load result files this notebook wrote
    with open("./results/single_long_" + str(n_extra) + ".p", "rb") as fh:
        data = pickle.load(fh)
    mean, lower, upper = compute_fracs(data, log_data=False, CI=None, median=False)
    means.append(mean)
    lowers.append(lower)
    uppers.append(upper)
    trans_probs.append(transmit_probs(data))

to_plot = pd.DataFrame({"days": extra_days, 'any': [np.mean(x[:, 1]) for x in trans_probs]}).sort_values(by="days")
prob_SEMs = pd.DataFrame({"days": extra_days, 'any': [stats.sem(x[:, 1]) for x in trans_probs]}).sort_values(by="days")

plt.figure()
# +23: length in days of the baseline curve (46 entries at 2 per day)
plt.plot(to_plot["days"] + 23, to_plot["any"], color="black")
plt.fill_between(to_plot["days"] + 23, to_plot["any"] - prob_SEMs["any"], to_plot["any"] + prob_SEMs["any"], color="black", alpha=0.1)
plt.ylabel("probability of passing on variant", fontsize=14)
plt.xlabel("infection length (days)", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
sns.despine()
plt.tight_layout()
#plt.savefig(".pdf", transparent=True)
plt.show()

In [None]:
# simulations for Fig. 2B: viral load reduced 10**eff-fold by treatment
viral_load_kinetics = [1000, 1e3, 1e5, 1e5, 1e6, 1e6, 1e6, 1e7, 1e7, 1e7, 1e8, 1e8, 1e9, 1e9, 1e9, 1e9, 1e8, 1e8, 1e8, 1e8, 1e7, 1e7, 1e7, 1e7, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e4, 1e4, 1e4, 1e3]
viral_load_adj = np.array(viral_load_kinetics) / 1000

efficacies = np.arange(4)
benefit = 0.2

num_iter = 1000
os.makedirs("./results", exist_ok=True)

for eff in efficacies:
    print(eff)
    viral_load_treat = viral_load_adj / 10**eff
    # Entries below one virion are dropped; for this rise-then-fall curve that
    # shortens it at both ends rather than setting those steps to zero.
    viral_load_treat = viral_load_treat[viral_load_treat >= 1]
    all_res = [simulate(viral_load_treat, None, benefit, 1, 1e-5, burst_size=1) for _ in range(num_iter)]
    with open("./results/treatment_" + str(eff) + ".p", "wb") as fh:
        pickle.dump(all_res, fh)

In [None]:
# plotting Fig. 2B: probability of passing on the variant vs fold reduction in viral load
efficacies = np.arange(4)

means = []
lowers = []
uppers = []
trans_probs = []
for eff in efficacies:
    # pickle.load can execute code; only load result files this notebook wrote
    with open("./results/treatment_" + str(eff) + ".p", "rb") as fh:
        data = pickle.load(fh)
    mean, lower, upper = compute_fracs(data, log_data=False, CI=None, median=False)
    means.append(mean)
    lowers.append(lower)
    uppers.append(upper)
    trans_probs.append(transmit_probs(data))

to_plot = pd.DataFrame({"eff": efficacies, 'any': [np.mean(x[:, 1]) for x in trans_probs]}).sort_values(by="eff")
prob_SEMs = pd.DataFrame({"eff": efficacies, 'any': [stats.sem(x[:, 1]) for x in trans_probs]}).sort_values(by="eff")

plt.figure()
plt.plot(to_plot["eff"], to_plot["any"], color="black")
plt.fill_between(to_plot["eff"], to_plot["any"] - prob_SEMs["any"], to_plot["any"] + prob_SEMs["any"], color="black", alpha=0.1)
plt.ylabel("probability of passing on variant", fontsize=14)
plt.xlabel("fold reduction in viral load", fontsize=14)
plt.xticks(ticks=efficacies, labels=10**efficacies, fontsize=14)
plt.yticks(fontsize=14)
sns.despine()
plt.tight_layout()
#plt.savefig(".pdf", transparent=True)
plt.show()

## Figure 3

In [None]:
# simulations for Fig. 3- two mutation combination (intermediate genotype
# carries a fitness cost `valley`; the double mutant gains `benefit`)
left_kinetics = [1000, 1e3, 1e5, 1e5, 1e6, 1e6, 1e6, 1e7, 1e7, 1e7, 1e8, 1e8]
right_kinetics = [1e9, 1e9, 1e9, 1e9, 1e8, 1e8, 1e8, 1e8, 1e7, 1e7, 1e7, 1e7, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e4, 1e4, 1e4, 1e3]
extra_weeks = [0, 1, 2, 3, 4, 5, 10, 15]
valley = 0.05
benefit = 0.2

num_iter = 1000
os.makedirs("./results", exist_ok=True)

# loop variable named `n_weeks` (not `time`) so the imported `time` module is not shadowed
for n_weeks in extra_weeks:
    print(n_weeks)
    # 14 extra entries (2 per day) per extra week at the 1e9 peak
    new_viral_load = left_kinetics + [1e9] * (n_weeks * 14) + right_kinetics
    new_viral_load = np.array(new_viral_load) / 1000
    all_res = [simulate(new_viral_load, valley, benefit, 2, 1e-5, burst_size=1) for _ in range(num_iter)]
    with open("./results/valley_long_" + str(n_weeks) + ".p", "wb") as fh:
        pickle.dump(all_res, fh)

In [None]:
# simulations for Fig. 3- three mutation combination (both intermediate
# genotypes carry a fitness cost `valley`; the triple mutant gains `benefit`)
left_kinetics = [1000, 1e3, 1e5, 1e5, 1e6, 1e6, 1e6, 1e7, 1e7, 1e7, 1e8, 1e8]
right_kinetics = [1e9, 1e9, 1e9, 1e9, 1e8, 1e8, 1e8, 1e8, 1e7, 1e7, 1e7, 1e7, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e4, 1e4, 1e4, 1e3]
extra_weeks = [0, 1, 2, 3, 4, 5, 10, 15]
valley = 0.05
benefit = 0.2

num_iter = 1000
os.makedirs("./results", exist_ok=True)

# loop variable named `n_weeks` (not `time`) so the imported `time` module is not shadowed
for n_weeks in extra_weeks:
    print(n_weeks)
    # 14 extra entries (2 per day) per extra week at the 1e9 peak
    new_viral_load = left_kinetics + [1e9] * (n_weeks * 14) + right_kinetics
    new_viral_load = np.array(new_viral_load) / 1000
    all_res = [simulate(new_viral_load, valley, benefit, 3, 1e-5, burst_size=1) for _ in range(num_iter)]
    with open("./results/valley_long_3_" + str(n_weeks) + ".p", "wb") as fh:
        pickle.dump(all_res, fh)

In [None]:
# plotting Fig. 3B: frequency trajectories of the two- and three-mutation
# combinations for the longest simulated infection

extra_weeks = [0, 1, 2, 3, 4, 5, 10, 15]


def load_valley_runs(prefix):
    """Load ./results/<prefix><week>.p for every entry of extra_weeks.

    Returns four lists (one entry per week): mean, lower and upper frequency
    trajectories from compute_fracs, and per-replicate transmission
    probabilities from transmit_probs(num_trans=10).
    """
    run_means, run_lowers, run_uppers, run_probs = [], [], [], []
    for week in extra_weeks:
        # pickle.load can execute code; only load result files this notebook wrote
        with open("./results/" + prefix + str(week) + ".p", "rb") as fh:
            data = pickle.load(fh)
        mean, lower, upper = compute_fracs(data, log_data=False, CI=None, median=False)
        run_means.append(mean)
        run_lowers.append(lower)
        run_uppers.append(upper)
        run_probs.append(transmit_probs(data, num_trans=10))
    return run_means, run_lowers, run_uppers, run_probs


def summarise_probs(probs, col):
    """Per-week replicate mean and SEM of the transmission probability in genotype column `col`."""
    mean_df = pd.DataFrame({"week": extra_weeks, 'any': [np.mean(x[:, col]) for x in probs]}).sort_values(by="week")
    sem_df = pd.DataFrame({"week": extra_weeks, 'any': [stats.sem(x[:, col]) for x in probs]}).sort_values(by="week")
    return mean_df, sem_df


# two-mutation runs: the double mutant is genotype column 2
means_hold, lowers_hold, uppers_hold, trans_probs = load_valley_runs("valley_long_")
to_plot_hold, SEMS_hold = summarise_probs(trans_probs, 2)

# three-mutation runs: the triple mutant is genotype column 3
means, lowers, uppers, trans_probs = load_valley_runs("valley_long_3_")
to_plot, prob_SEMs = summarise_probs(trans_probs, 3)

plt.figure()
mut_index = 3
i = 7  # extra_weeks[7] == 15 extra weeks
# rows are half-day steps, so divide by 14 to get weeks
xplot = np.arange(len(means[i][:, mut_index])) / (2 * 7)
plt.plot(xplot, means[i][:, mut_index], color="maroon")
plt.fill_between(xplot, lowers[i][:, mut_index], uppers[i][:, mut_index], color="maroon", alpha=0.1)

mut_index = 2
plt.plot(xplot, means_hold[i][:, mut_index], color="blue")
plt.fill_between(xplot, lowers_hold[i][:, mut_index], uppers_hold[i][:, mut_index], color="blue", alpha=0.1)

plt.ylabel("mean intrahost\nfrequency of variant", fontsize=14)
plt.xlabel("weeks since infection started", fontsize=14)
plt.xticks(ticks=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18], fontsize=14)
plt.yticks(fontsize=14)

sns.despine()
plt.tight_layout()
#plt.savefig(".pdf", transparent=True)
plt.show()

In [None]:
# plotting Fig. 3C: probability of passing on the two-mutation (blue) and
# three-mutation (maroon) combinations vs total infection length in weeks
# (+3 for the baseline curve)
series_styles = [(to_plot_hold, SEMS_hold, "blue"), (to_plot, prob_SEMs, "maroon")]

plt.figure()
for centre_df, sem_df, colour in series_styles:
    total_weeks = centre_df["week"] + 3
    plt.plot(total_weeks, centre_df["any"], color=colour)
    plt.fill_between(total_weeks, centre_df["any"] - sem_df["any"], centre_df["any"] + sem_df["any"], color=colour, alpha=0.1)

plt.ylabel("probability of passing on variant", fontsize=14)
plt.xlabel("infection length (wks)", fontsize=14)
plt.yscale("log")
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.ylim(1e-10, 2e-1)
sns.despine()
plt.tight_layout()
#plt.savefig(".pdf", transparent=True)
plt.show()

In [None]:
# plotting Fig. 3D: daily infections with the new two-mutation combination
# as a function of how common long-term shedders are
lam = 1.5  # R0 of the new combination


def find_pext(x):
    """Residual of x = 1 - exp(-lam*x), the survival-probability equation of a
    branching process with Poisson(lam) offspring (x = 0 is always a root)."""
    return 1 - np.exp(-lam * x) - x


N = 5e4
p_surv = scipy.optimize.broyden1(find_pext, 0.1, f_tol=1e-14)
# Both terms are scaled by p_surv * N so they are in the same units. The
# short-infection term used to be a bare transmission probability, which was
# then added to a daily count.
long_nvar_2mut = np.sum(to_plot_hold["any"].iloc[1:] * long_dist * p_surv * N)
short_nvar_2mut = to_plot_hold["any"].iloc[0] * p_surv * N

frac_long = [0.0001, 0.001, 0.01, 0.05]
combined_nvar = [short_nvar_2mut * (1 - p) + long_nvar_2mut * p for p in frac_long]
label_percent = [f"{x:g}%" for x in np.array(frac_long) * 100]

plt.figure()
# `fill` is a boolean flag; the bar colour must be passed via `color`
plt.bar(x=range(len(frac_long)), height=combined_nvar, color="blue")
plt.xticks(ticks=range(len(frac_long)), labels=label_percent, fontsize=14)
plt.yticks(fontsize=14)
plt.ylabel("number of infections/day with\nnew two-mutation combination", fontsize=14)
plt.xlabel("long-term viral shedder frequency in population", fontsize=14)
sns.despine()
plt.tight_layout()
#plt.savefig(".eps", transparent=True)
plt.show()

## Figure 4

In [None]:
# compiling data from previous simulations and plotting (Figure 4):
# expected daily infections with the new two-mutation combination under
# several interventions.
# All previous simulations for the two mutation combination must be run prior to this.
# NOTE(review): the "reduced viral load" bar reads
# ./results/treated_valley_long_<week>.p, which no cell in this notebook
# writes — TODO add the simulation cell that produces those files.

extra_weeks = [0, 1, 2, 3, 4, 5, 10, 15]
N = 5e4
DOUBLE_MUT = 2  # genotype column of the two-mutation combination (num_mut=2 runs)


def survival_prob(lam):
    """Solve x = 1 - exp(-lam*x) with broyden1 from x0 = 0.1 and return a float.

    This is the survival-probability equation of a branching process with
    Poisson(lam) offspring. x = 0 is always a root; the 0.1 start is relied on
    to reach the non-zero root for lam > 1.
    """
    root = scipy.optimize.broyden1(lambda x: 1 - np.exp(-lam * x) - x, 0.1, f_tol=1e-14)
    return np.asarray(root).item()


def mean_double_mut_probs(prefix, num_trans):
    """Replicate-mean probability of transmitting the double mutant, one value per extra_weeks entry.

    Reads ./results/<prefix><week>.p. pickle.load can execute code, so only
    point this at result files this notebook wrote.
    """
    probs = []
    for week in extra_weeks:
        with open("./results/" + prefix + str(week) + ".p", "rb") as fh:
            data = pickle.load(fh)
        probs.append(np.mean(transmit_probs(data, num_trans=num_trans)[:, DOUBLE_MUT]))
    return np.array(probs)


def daily_new_combos(probs, p_surv, frac_long):
    """Expected daily infections with the new combination when a fraction `frac_long` of infections are long.

    probs[0] is the baseline-length infection; probs[1:] are the extended
    infections, weighted by long_dist. Both terms are scaled by p_surv * N.
    The baseline term previously omitted this factor, so a bare probability
    was mixed with a daily count.
    """
    short_nvar = probs[0] * p_surv * N
    long_nvar = np.sum(probs[1:] * long_dist * p_surv * N)
    return short_nvar * (1 - frac_long) + long_nvar * frac_long


p_surv_control = survival_prob(1.5)
control_probs = mean_double_mut_probs("valley_long_", num_trans=10)

combined_nvar = [
    # control: 0.1% of infections are long
    daily_new_combos(control_probs, p_surv_control, 0.001),
    # fewer long infections: 10x lower frequency
    daily_new_combos(control_probs, p_surv_control, 0.001 * 0.1),
    # reduced viral load (treated simulations)
    daily_new_combos(mean_double_mut_probs("treated_valley_long_", num_trans=10), p_surv_control, 0.001),
    # fewer virions transmitted per transmission
    daily_new_combos(mean_double_mut_probs("valley_long_", num_trans=1), p_surv_control, 0.001),
    # reduced transmissibility: 8 virions per transmission and R0 = 1.05
    daily_new_combos(mean_double_mut_probs("valley_long_", num_trans=8), survival_prob(1.05), 0.001),
]

plt.figure(figsize=(9, 6))
treatment_ticks = ["control", "fewer long\ninfections", "reduced\nviral load", "fewer virions\ntransmitted", "reduced\ntransmissibility"]
plt.bar(treatment_ticks, combined_nvar, width=0.6, color=["lightgrey", "cornflowerblue", "cornflowerblue", "mistyrose", "tomato"])
for tick, value in zip(treatment_ticks, combined_nvar):
    plt.text(tick, value + 0.05, str(round(value, 2)), fontsize=14, ha="center")

plt.ylabel("number of infections/day with\nnew two-mutation combination", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
sns.despine()
plt.tight_layout()
#plt.savefig(".eps", transparent=True)
plt.show()