In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from jax import random
# Default size for ad-hoc figures; the manuscript figures below pass their own figsize.
plt.rcParams.update({"figure.figsize": (10, 8)})

In [None]:
from rt_from_frequency_dynamics import discretise_gamma
from rt_from_frequency_dynamics import get_standard_delays
from rt_from_frequency_dynamics import FreeGrowthModel, FixedGrowthModel

from rt_from_frequency_dynamics import get_location_LineageData
from rt_from_frequency_dynamics import fit_SVI, MultiPosterior
from rt_from_frequency_dynamics import sample_loaded_posterior
from rt_from_frequency_dynamics import DefaultAes

# Load data

In [None]:
# Case counts and per-variant sequence counts for the chosen dataset (tab-separated).
data_name = "variants-us"
raw_cases, raw_seq = (
    pd.read_csv(tsv_path, sep="\t")
    for tsv_path in (
        f"../data/{data_name}_location-case-counts.tsv",
        f"../data/{data_name}_location-variant-sequence-counts.tsv",
    )
)

# Load US state posteriors (saved SVI fits)

In [None]:
# States whose saved posteriors are loaded and plotted below.
locations = [
    "Washington",
    "California",
    "New York",
    "Michigan",
    "Florida",
]
# NOTE(review): `optimizer` and `num_samples` are not referenced by the
# load_models calls in this notebook, which pass num_samples=1000 explicitly.
optimizer = numpyro.optim.Adam(step_size=1e-2)
num_samples = 3000

In [None]:
def load_models(rc, rs, locations, model_type, path=".", num_samples=1000):
    """Load saved posteriors for several locations into one MultiPosterior.

    Parameters
    ----------
    rc, rs : pandas.DataFrame
        Raw case counts and raw sequence counts, forwarded to
        ``get_location_LineageData``.
    locations : sequence of str
        Location names; each name is also the ``name`` passed to
        ``sample_loaded_posterior``.
    model_type : callable
        Model class, constructed once as ``model_type(g, delays, 7, 0)``.
    path : str
        Directory forwarded to ``sample_loaded_posterior``.
    num_samples : int
        Forwarded to ``sample_loaded_posterior``.

    Returns
    -------
    MultiPosterior
        One posterior per location, added in ``locations`` order.
    """
    g, delays = get_standard_delays()
    # A single model instance is shared by every location.
    LM = model_type(g, delays, 7, 0)
    MP = MultiPosterior()
    for n_done, loc in enumerate(locations, start=1):
        LD = get_location_LineageData(rc, rs, loc)
        MP.add_posterior(
            sample_loaded_posterior(LD, LM, num_samples=num_samples, path=path, name=loc)
        )
        print(f"Location {loc} finished {n_done} / {len(locations)}")
    return MP

In [None]:
# Estimate directories for the free-R_t and fixed-growth-advantage models.
path_base = f"../estimates/{data_name}"
path_free = f"{path_base}/free"
path_fixed = f"{path_base}/fixed"

In [None]:
MP_free = load_models(
    rc=raw_cases, rs=raw_seq, locations=locations, model_type=FreeGrowthModel,
    path=path_free, num_samples=1000,
)

In [None]:
MP_fixed = load_models(
    rc=raw_cases, rs=raw_seq, locations=locations, model_type=FixedGrowthModel,
    path=path_fixed, num_samples=1000,
)

## Making plots

In [None]:
import matplotlib
import matplotlib.transforms as mtransforms

# Large manuscript font applied to every figure drawn after this cell.
# NOTE: if Helvetica is not installed, matplotlib warns and falls back to another font.
font = dict(family='Helvetica', weight='light', size=32)
matplotlib.rc('font', **font)

In [None]:
from rt_from_frequency_dynamics.plotfunctions import *
# Credible levels and matching band transparencies from the package defaults.
ps, alphas = DefaultAes.ps, DefaultAes.alphas
# Fixed colour per variant so every figure uses the same palette.
v_colors = [
    "#2e5eaa", "#5adbff", "#56e39f", "#b4c5e4",
    "#f03a47", "#f5bb00", "#9e4244", "#808080",
]
v_names = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Mu', 'other']
color_map = dict(zip(v_names, v_colors))

In [None]:
def unpack_model(MP, loc):
    """Return ``(dataset, LD)`` from the posterior stored under ``loc`` in ``MP``."""
    entry = MP.get(loc)
    dataset, LD = entry.dataset, entry.LD
    return dataset, LD

## Plotting free Rt

In [None]:

def figure_free_rt(dataset, LD, ps, alphas, colors):
    """Seven-panel free-growth R_t figure for one location.

    Panels on a 4 x 6 grid (labels (a)-(c) are drawn on the first three listed):
      (a) top left      -- observed cases + ``plot_posterior_smooth_EC``
      (b) top right     -- observed cases + ``plot_posterior_I`` per variant
          middle left   -- ``plot_posterior_average_R``
          middle right  -- ``plot_R_censored`` with thres=0.001
      (c) bottom left   -- ``plot_total_by_obs_frequency`` with the daily
                           sequence totals (upper) and a constant total of 1 (lower)
          bottom middle -- ``plot_posterior_frequency`` + observed frequencies
          bottom right  -- ``plot_total_by_median_frequency``

    Parameters
    ----------
    dataset
        Posterior samples for the location (e.g. from ``unpack_model``).
    LD
        Lineage data for the same location; ``dates``, ``cases``,
        ``seq_counts`` and ``seq_names`` are read here.
    ps, alphas
        Credible levels and their band alphas, forwarded to the plot helpers.
    colors
        One colour per entry of ``LD.seq_names``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(30, 24))
    gs = fig.add_gridspec(nrows=4, ncols=6, height_ratios=[2.0, 1.5, 1.0, 1.0])
    single_color = "#3A3B3C"

    # Top left (a): smoothed cases.
    ax1 = fig.add_subplot(gs[0, :3])
    plot_cases(ax1, LD)
    plot_posterior_smooth_EC(ax1, dataset, ps, alphas, single_color)
    ax1.set_ylabel("Posterior smoothed cases")

    # Top right (b): per-variant curves on the same y scale as (a).
    ax2 = fig.add_subplot(gs[0, 3:], sharey=ax1)
    plot_cases(ax2, LD)
    plot_posterior_I(ax2, dataset, ps, alphas, colors)
    plt.setp(ax2.get_yticklabels(), visible=False)

    # Middle left: average R_t.
    ax3 = fig.add_subplot(gs[1, :3], sharex=ax1)
    plot_posterior_average_R(ax3, dataset, ps, alphas, single_color)
    add_dates(ax3, LD.dates)
    ax3.set_ylabel(r"$R_{t}$")

    # Middle right: per-variant R_t, sharing the R_t scale of the left panel.
    ax4 = fig.add_subplot(gs[1, 3:], sharex=ax2, sharey=ax3)
    plot_R_censored(ax4, dataset, ps, alphas, colors, thres=0.001)
    add_dates(ax4, LD.dates)
    plt.setp(ax4.get_yticklabels(), visible=False)

    # Bottom left (c): totals split by observed frequency.
    ax5a = fig.add_subplot(gs[2, 0:2])
    plot_total_by_obs_frequency(ax5a, LD, LD.seq_counts.sum(axis=1), colors)

    ax5b = fig.add_subplot(gs[3, 0:2], sharex=ax5a)
    # Constant total of 1 per day. numpy is used because it is explicitly
    # imported; `jnp` is not imported anywhere in this notebook.
    plot_total_by_obs_frequency(ax5b, LD, np.full(LD.cases.shape[-1], fill_value=1), colors)
    add_dates(ax5b, LD.dates, sep=2)

    # Bottom middle: posterior frequencies, observed points sized by sqrt(count).
    ax6 = fig.add_subplot(gs[2:, 2:4])
    plot_posterior_frequency(ax6, dataset, ps, alphas, colors)
    plot_observed_frequency_size(ax6, LD, colors, lambda n: 2.5 * np.sqrt(n))
    add_dates(ax6, LD.dates, sep=2)
    ax6.set_ylabel("Posterior variant frequencies")

    # Bottom right: cases split by median posterior frequency.
    ax7 = fig.add_subplot(gs[2:, 4:6], sharey=ax1)
    plot_total_by_median_frequency(ax7, dataset, LD, LD.cases, colors)
    add_dates(ax7, LD.dates, sep=2)
    ax7.set_ylabel("Median variant cases")

    # Panel labels, offset up-left of each axes' top-left corner (in points / 72).
    label_offset = mtransforms.ScaledTranslation(-42/72, 14/72, fig.dpi_scale_trans)
    for label, ax in zip(["(a)", "(b)", "(c)"], [ax1, ax2, ax5a]):
        ax.text(0.0, 1.0, label, transform=ax.transAxes + label_offset,
                fontsize='large', va='bottom', fontfamily='serif')

    # Shared colour legend along the bottom of the figure.
    patches = [matplotlib.patches.Patch(color=c, label=name)
               for name, c in zip(LD.seq_names, colors)]
    legend = fig.legend(patches, LD.seq_names, ncol=len(LD.seq_names), loc="lower center")
    frame = legend.get_frame()
    frame.set_linewidth(2.0)
    frame.set_edgecolor("k")
    fig.tight_layout()
    # Reserve room at the bottom for the legend after tight_layout.
    fig.subplots_adjust(bottom=0.1)
    return fig

In [None]:
# Washington is the single-location example for the free-R_t figure.
dataset, LD = unpack_model(MP_free, "Washington")
colors = list(map(color_map.__getitem__, LD.seq_names))

In [None]:
fig_1 = figure_free_rt(dataset=dataset, LD=LD, ps=ps, alphas=alphas, colors=colors)

In [None]:
# Plot exactly the locations loaded into MP_free; a hand-copied list could drift
# from `locations` and request a posterior that was never loaded.
fig_free_rt_locs = list(locations)

In [None]:
# Save one free-R_t figure per location; spaces in names become '-' in the file name.
for loc in fig_free_rt_locs:
    _loc = loc.replace(" ", "-")
    dataset, LD = unpack_model(MP_free, loc)
    colors = list(map(color_map.__getitem__, LD.seq_names))
    fig_fg_loc = figure_free_rt(dataset=dataset, LD=LD, ps=ps, alphas=alphas, colors=colors)
    fig_fg_loc.savefig(
        f"../manuscript/figs/free_rt_{_loc}.png",
        facecolor="w",
        bbox_inches='tight',
    )

## Plotting growth advantage

In [None]:
def figure_fixed_growth(dataset, LD, ps, alphas, colors):
    """Five-panel fixed-growth-advantage figure for one location.

    Panels (labelled (a)-(e)) on a 4 x 2 grid:
      (a) left, rows 0-1  -- observed cases + ``plot_posterior_smooth_EC``
      (b) right, rows 0-1 -- observed cases + ``plot_posterior_I`` per variant
      (c) left, rows 2-3  -- ``plot_posterior_frequency`` + observed frequencies
      (d) right, row 2    -- ``plot_R_censored`` with thres=0.005
      (e) right, row 3    -- ``plot_growth_advantage``

    Parameters
    ----------
    dataset
        Posterior samples for the location (e.g. from ``unpack_model``).
    LD
        Lineage data for the same location; ``dates`` and ``seq_names`` are read here.
    ps, alphas
        Credible levels and their band alphas, forwarded to the plot helpers.
    colors
        One colour per entry of ``LD.seq_names``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(30, 20))
    gs = fig.add_gridspec(nrows=4, ncols=2)
    single_color = "#3A3B3C"

    # Top left (a): smoothed cases.
    ax1 = fig.add_subplot(gs[:2, 0])
    plot_cases(ax1, LD)
    plot_posterior_smooth_EC(ax1, dataset, ps, alphas, single_color)
    ax1.set_ylabel("Posterior smoothed cases")

    # Top right (b): per-variant curves on the same y scale as (a).
    ax2 = fig.add_subplot(gs[:2, 1], sharey=ax1)
    plot_cases(ax2, LD)
    plot_posterior_I(ax2, dataset, ps, alphas, colors)
    plt.setp(ax2.get_yticklabels(), visible=False)

    # Bottom left (c): posterior frequencies, observed points sized by sqrt(count).
    # numpy is used because `jnp` is not imported anywhere in this notebook.
    ax3 = fig.add_subplot(gs[2:, 0], sharex=ax1)
    plot_posterior_frequency(ax3, dataset, ps, alphas, colors)
    plot_observed_frequency_size(ax3, LD, colors, lambda n: 2.5 * np.sqrt(n))
    add_dates(ax3, LD.dates)
    ax3.set_ylabel("Posterior lineage frequencies")

    # Bottom right 1 (d): per-variant R_t.
    ax4 = fig.add_subplot(gs[2, 1], sharex=ax2)
    plot_R_censored(ax4, dataset, ps, alphas, colors, thres=0.005)
    add_dates(ax4, LD.dates)
    ax4.set_ylabel(r"$R_{t}$")

    # Bottom right 2 (e): growth advantage per variant.
    ax5 = fig.add_subplot(gs[3, 1])
    plot_growth_advantage(ax5, dataset, LD, ps, alphas, colors)
    ax5.set_ylabel("Growth Advantage")

    # Panel labels, offset up-left of each axes' top-left corner (in points / 72).
    label_offset = mtransforms.ScaledTranslation(-32/72, 8/72, fig.dpi_scale_trans)
    for label, ax in zip(["(a)", "(b)", "(c)", "(d)", "(e)"], [ax1, ax2, ax3, ax4, ax5]):
        ax.text(0.0, 1.0, label, transform=ax.transAxes + label_offset,
                fontsize='large', va='bottom', fontfamily='serif')

    # Shared colour legend along the bottom of the figure.
    patches = [matplotlib.patches.Patch(color=c, label=name)
               for name, c in zip(LD.seq_names, colors)]
    legend = fig.legend(patches, LD.seq_names, ncol=len(LD.seq_names), loc="lower center")
    frame = legend.get_frame()
    frame.set_linewidth(2.0)
    frame.set_edgecolor("k")

    # One layout pass on *this* figure once all artists exist (not the pyplot
    # "current figure"), then reserve room at the bottom for the legend.
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.1)
    return fig

In [None]:
# Washington is the single-location example for the fixed-growth figure.
dataset, LD = unpack_model(MP_fixed, "Washington")
colors = list(map(color_map.__getitem__, LD.seq_names))
fig_2 = figure_fixed_growth(dataset=dataset, LD=LD, ps=ps, alphas=alphas, colors=colors)

In [None]:
# Plot exactly the locations loaded into MP_fixed; a hand-copied list could drift
# from `locations` and request a posterior that was never loaded.
fig_fixed_growth_locs = list(locations)

In [None]:
# Save one fixed-growth figure per location. Spaces in location names are
# replaced by '-' (e.g. "New-York") so file names contain no spaces and follow
# the same convention as the free-R_t figures.
for loc in fig_fixed_growth_locs:
    dataset, LD = unpack_model(MP_fixed, loc)
    colors = [color_map[v] for v in LD.seq_names]
    fig_fg_loc = figure_fixed_growth(dataset, LD, ps, alphas, colors)
    _loc = loc.replace(" ", "-")
    fig_fg_loc.savefig(f"../manuscript/figs/fixed_growth_{_loc}.png", facecolor="w", bbox_inches='tight')

## Figure: Growth advantages

In [None]:
# Per-location growth-advantage summaries from the fixed-growth model.
ga_df = pd.read_csv(
    f"{path_base}/{data_name}_ga-combined-fixed.tsv",
    sep="\t",
)

In [None]:
def figure_growth_advantage(ga_df, LD, ps, alphas, colors, jitter_seed=0):
    """Two-panel summary of per-location growth-advantage estimates.

    (a) top: violin of each variant's per-location median growth advantage,
        with the individual medians overlaid and horizontally jittered.
    (b) bottom: per-location median and credible intervals, one colour per variant.

    Parameters
    ----------
    ga_df : pandas.DataFrame
        Long table with columns ``variant``, ``location``, ``median_ga`` and
        ``ga_lower_{pct}`` / ``ga_upper_{pct}`` for each level in ``ps``
        (``pct = int(p * 100)``).
    LD, alphas
        Unused; kept for call compatibility with the other figure functions.
    ps : sequence of float
        Credible levels; the ``i``-th interval is drawn with line width
        ``1.5 + i`` (1.5, 2.5, 3.5, ... for any number of levels).
    colors : sequence
        One colour per variant, ordered like ``pd.unique(ga_df.variant)``.
    jitter_seed : int, default 0
        Seed for the scatter jitter in (a), so re-running gives the same figure.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(28, 20))

    variants = pd.unique(ga_df.variant)
    locations = pd.unique(ga_df.location)
    location_map = {loc: i for i, loc in enumerate(locations)}
    rng = np.random.default_rng(jitter_seed)

    # Bottom panel (b): per-location medians with nested credible intervals.
    ax1 = fig.add_subplot(2, 1, 2)
    ax1.axhline(y=1, lw=2, linestyle='dashed', color="k")

    for v, var in enumerate(variants):
        this_lineage = ga_df[ga_df.variant == var]
        location_num = this_lineage["location"].map(location_map)
        median = this_lineage.median_ga.values
        ax1.scatter(location_num, median, color=colors[v], edgecolors="k", s=45, zorder=3)

        # One error bar per credible level; thicker lines for later levels.
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = median - this_lineage[f"ga_lower_{_p}"].values
            r_err = this_lineage[f"ga_upper_{_p}"].values - median
            ax1.errorbar(location_num, median, yerr=[l_err, r_err], fmt='none',
                         color=colors[v], elinewidth=1.5 + i)

    ax1.set_xticks(np.arange(len(locations)))
    ax1.set_xticklabels([loc.replace("_", " ") for loc in locations], rotation=90)
    ax1.set_ylabel("Growth Advantage")

    # Top panel (a): distribution of per-location medians for each variant.
    ax2 = fig.add_subplot(2, 1, 1)
    ax2.axhline(y=1, lw=2, linestyle='dashed', color="k")

    violin_data = [ga_df[ga_df.variant == var].median_ga.values for var in variants]
    parts = ax2.violinplot(violin_data, showmeans=False, showmedians=False, showextrema=False)
    for i, pc in enumerate(parts["bodies"]):
        pc.set_facecolor(colors[i])
        pc.set_edgecolor('black')
        pc.set_alpha(1)

    for v, var in enumerate(variants):
        this_lineage = ga_df[ga_df.variant == var]
        # Violins sit at x = 1, 2, ...; jitter the points around their violin.
        x_jittered = v + 1 + rng.normal(0, 0.02, len(this_lineage))
        ax2.scatter(x_jittered, this_lineage.median_ga.values, color=colors[v],
                    edgecolors="k", s=45, zorder=3)

    ax2.set_ylabel("Median Growth Advantage")
    ax2.set_xticks(np.arange(1, len(variants) + 1))
    ax2.set_xticklabels(variants)

    # Panel labels, offset up-left of each axes' top-left corner (in points / 72).
    label_offset = mtransforms.ScaledTranslation(-32/72, 8/72, fig.dpi_scale_trans)
    for label, ax in zip(["(a)", "(b)"], [ax2, ax1]):
        ax.text(0.0, 1.0, label, transform=ax.transAxes + label_offset,
                fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
# One colour per variant, in the order the variants first appear in ga_df.
colors = list(map(color_map.__getitem__, pd.unique(ga_df.variant)))

In [None]:
fig_3 = figure_growth_advantage(ga_df=ga_df, LD=LD, ps=ps, alphas=alphas, colors=colors)

In [None]:
 fig_3.savefig("../manuscript/figs/growth_advantages.png", facecolor="w", bbox_inches='tight')

## Figure: Rt consensus

In [None]:
# Per-location, per-date R_t summaries from the free-growth model.
rt_df_free = pd.read_csv(
    f"{path_base}/{data_name}_Rt-combined-free.tsv",
    sep="\t",
)

In [None]:
def figure_rt_consensus(rt_df, LD, ps, alphas, colors, thres=0.001):
    """Grid of per-variant R_t trajectories, overlaying every location.

    One subplot per variant (2 rows, as many columns as needed), all sharing
    the y axis of the first. For each location, days whose ``median_freq`` is
    below ``thres`` are dropped; the median R_t is drawn as a faint black line
    and each credible interval as a band in the variant's colour.

    Parameters
    ----------
    rt_df : pandas.DataFrame
        Table with columns ``variant``, ``location``, ``date``, ``median_freq``,
        ``median_R`` and ``R_lower_{pct}`` / ``R_upper_{pct}`` for each level in
        ``ps`` (``pct = int(p * 100)``).
    LD
        Lineage data whose ``dates`` are passed to ``add_dates`` for tick labels.
        NOTE(review): x positions are indices into the sorted unique dates of
        ``rt_df``; this assumes ``LD.dates`` is the same date axis - confirm.
    ps, alphas
        Credible levels and their band alphas.
    colors : sequence
        One colour per variant, ordered like ``pd.unique(rt_df.variant)``.
    thres : float, default 0.001
        Minimum median frequency for a day to be drawn.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(30, 15))
    variants = pd.unique(rt_df.variant)
    locations = pd.unique(rt_df.location)
    dates = np.sort(pd.unique(rt_df.date))
    dates_map = {d: i for i, d in enumerate(dates)}

    n_rows = 2
    n_cols = (len(variants) + n_rows - 1) // n_rows  # ceiling division

    ax_list = []
    for v, var in enumerate(variants):
        # Every panel shares the y axis of the first one.
        sharey = ax_list[0] if ax_list else None
        ax = fig.add_subplot(n_rows, n_cols, v + 1, sharey=sharey)
        ax.axhline(y=1, lw=2, linestyle='dashed', color="k")
        this_variant = rt_df[rt_df.variant == var]

        for loc in locations:
            this_loc = this_variant[this_variant.location == loc]
            included = this_loc.median_freq.to_numpy() >= thres
            dates_num = this_loc["date"].map(dates_map).to_numpy()[included]

            median_R = this_loc["median_R"].to_numpy()[included]
            ax.plot(dates_num, median_R, color='k', alpha=0.1)

            # One band per credible level.
            for i, p in enumerate(ps):
                _p = int(p * 100)
                lower = this_loc[f"R_lower_{_p}"].to_numpy()[included]
                upper = this_loc[f"R_upper_{_p}"].to_numpy()[included]
                ax.fill_between(dates_num, lower, upper, color=colors[v], alpha=alphas[i])

        add_dates(ax, LD.dates, sep=2)

        # Only the first column keeps y tick labels and the axis label.
        if v % n_cols != 0:
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            ax.set_ylabel("Effective Reproduction Number")

        ax.set_title(var)
        ax_list.append(ax)

    # Lay out once, on this figure, after all panels exist.
    fig.tight_layout()
    return fig

In [None]:
# One colour per variant, in the order the variants first appear in rt_df_free.
colors = list(map(color_map.__getitem__, pd.unique(rt_df_free.variant)))
fig_4 = figure_rt_consensus(rt_df=rt_df_free, LD=LD, ps=ps, alphas=alphas, colors=colors, thres=0.001)

In [None]:
fig_4.savefig(
    "../manuscript/figs/rt_consensus.png",
    facecolor="w",
    bbox_inches='tight',
)

## Figure: Sensitivity of $R_t$ to the generation time

In [None]:
# R_t estimates re-run with a varied generation-time mean and a varied generation-time sd.
rt_sens_m = pd.read_csv(
    "../estimates/variants-us-sensitivity-means/variants-us-sensitivity-means_Rt-combined-free.tsv",
    sep="\t",
)
rt_sense_sd = pd.read_csv(
    "../estimates/variants-us-sensitivity-sd/variants-us-sensitivity-sd_Rt-combined-free.tsv",
    sep="\t",
)

In [None]:
def figure_gen_sens(r_m, r_s, ps, alphas, colors, variant="Delta", date="2021-07-01"):
    """Plot one variant's R_t on one date against the generation-time settings.

    Panel (a) uses the runs in ``r_m`` (varied generation-time mean); panel (b)
    uses the runs in ``r_s`` (varied generation-time sd). Each run is stored as
    a pseudo-location named ``g_mean_<value>`` / ``g_sd_<value>``, and the
    numeric suffix becomes the x coordinate.

    Parameters
    ----------
    r_m, r_s : pandas.DataFrame
        R_t tables with columns ``location``, ``variant``, ``date``, ``median_R``
        and ``R_lower_{pct}`` / ``R_upper_{pct}`` for each level in ``ps``
        (``pct = int(p * 100)``).
    ps : sequence of float
        Credible levels; the ``i``-th interval is drawn with line width ``3 + 2 * i``.
    alphas
        Unused; kept for call compatibility with the other figure functions.
    colors : sequence
        One colour per variant, ordered like ``pd.unique(r_m.variant)``.
    variant : str, default "Delta"
        Variant to plot.
    date : str, default "2021-07-01"
        Value of the ``date`` column to plot.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``variant`` does not occur in ``r_m``.
    """
    # Resolve the colour from this function's own input, not a global table.
    variants = list(pd.unique(r_m.variant))
    if variant not in variants:
        raise ValueError(f"variant {variant!r} not found in r_m")
    color = colors[variants.index(variant)]

    def _plot_panel(ax, df, prefix):
        # One row per sensitivity run for the chosen variant and date.
        sub = df[(df.variant == variant) & (df.date == date)]
        xs = [float(name.replace(prefix, "")) for name in sub.location.values]
        median = sub.median_R.values
        ax.scatter(xs, median, edgecolors="k", s=160, zorder=3, color=color)
        # One error bar per credible level; thicker lines for later levels.
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = median - sub[f"R_lower_{_p}"].values
            r_err = sub[f"R_upper_{_p}"].values - median
            ax.errorbar(xs, median, yerr=[l_err, r_err], fmt='none',
                        color=color, elinewidth=3 + 2 * i)

    fig = plt.figure(figsize=(28, 14))

    ax_m = fig.add_subplot(1, 2, 1)
    _plot_panel(ax_m, r_m, "g_mean_")
    ax_m.set_ylabel(f"Effective Reproduction Number ({date})")
    ax_m.set_xlabel("Mean of generation time")

    ax_s = fig.add_subplot(1, 2, 2, sharey=ax_m)
    _plot_panel(ax_s, r_s, "g_sd_")
    ax_s.set_xlabel("Standard deviation of generation time")

    # Panel labels, offset left of each axes' top-left corner (in points / 72).
    label_offset = mtransforms.ScaledTranslation(-72/72, 0/72, fig.dpi_scale_trans)
    for label, ax in zip(["(a)", "(b)"], [ax_m, ax_s]):
        ax.text(0.0, 1.0, label, transform=ax.transAxes + label_offset,
                fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
# `lineage_colors` is never defined in this notebook; use the shared palette,
# ordered like the variants in the sensitivity table.
fig_5 = figure_gen_sens(rt_sens_m, rt_sense_sd, ps, alphas,
                        [color_map[v] for v in pd.unique(rt_sens_m.variant)])

In [None]:
fig_5.savefig(
    "../manuscript/figs/generation_time_sensitivity.png",
    facecolor="w",
    bbox_inches='tight',
)

## Figure: Sensitivity of the growth rate $r$ to the generation time

In [None]:
def R_gamma_to_sens(R, m, s):
    """Convert a reproduction number ``R`` into an exponential growth rate.

    The mean and sd are taken from the discretised generation time
    ``discretise_gamma(m, s)``, whose entries are weights for days 1, 2, ....
    The result is ``(R ** (sd**2 / mean**2) - 1) * mean / sd**2``, i.e.
    ``(R ** (1/k) - 1) / theta`` for a gamma with shape ``k`` and scale
    ``theta`` matching those moments.
    """
    g = discretise_gamma(m, s)
    weighted_days = list(enumerate(g, start=1))
    mn = np.sum([p * day for day, p in weighted_days])
    second_moment = np.sum([p * day ** 2 for day, p in weighted_days])
    sd = np.sqrt(second_moment - mn ** 2)
    inv_shape = sd ** 2 / mn ** 2  # 1 / k
    inv_scale = mn / sd ** 2       # 1 / theta
    return (np.float_power(R, inv_shape) - 1) * inv_scale

In [None]:
def figure_little_r_sens(r_m, r_s, ps, alphas, colors, variant="Delta",
                         date="2021-07-01", g_mean=5.2, g_sd=1.72):
    """Plot one variant's growth rate on one date against the generation-time settings.

    R_t values are converted with ``R_gamma_to_sens``. Panel (a) uses the runs
    in ``r_m`` (varied mean, sd held at ``g_sd``); panel (b) uses the runs in
    ``r_s`` (varied sd, mean held at ``g_mean``). Each run is stored as a
    pseudo-location named ``g_mean_<value>`` / ``g_sd_<value>``, and the numeric
    suffix becomes the x coordinate.

    Parameters
    ----------
    r_m, r_s : pandas.DataFrame
        R_t tables with columns ``location``, ``variant``, ``date``, ``median_R``
        and ``R_lower_{pct}`` / ``R_upper_{pct}`` for each level in ``ps``
        (``pct = int(p * 100)``).
    ps : sequence of float
        Credible levels; the ``i``-th interval is drawn with line width ``3 + 2 * i``.
    alphas
        Unused; kept for call compatibility with the other figure functions.
    colors : sequence
        One colour per variant, ordered like ``pd.unique(r_m.variant)``.
    variant : str, default "Delta"
        Variant to plot.
    date : str, default "2021-07-01"
        Value of the ``date`` column to plot.
    g_mean, g_sd : float, defaults 5.2 and 1.72
        Generation-time mean/sd held fixed in the sd/mean panel respectively.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``variant`` does not occur in ``r_m``.
    """
    # Resolve the colour from this function's own input, not a global table.
    variants = list(pd.unique(r_m.variant))
    if variant not in variants:
        raise ValueError(f"variant {variant!r} not found in r_m")
    color = colors[variants.index(variant)]

    def _plot_panel(ax, df, prefix, gen_params):
        # One row per sensitivity run for the chosen variant and date.
        sub = df[(df.variant == variant) & (df.date == date)]
        xs = [float(name.replace(prefix, "")) for name in sub.location.values]
        params = [gen_params(x) for x in xs]  # (mean, sd) per run

        def to_r(R_values):
            return np.array([R_gamma_to_sens(R, m, s) for (m, s), R in zip(params, R_values)])

        med_r = to_r(sub.median_R.values)
        ax.scatter(xs, med_r, edgecolors="k", s=160, zorder=3, color=color)
        # One error bar per credible level; thicker lines for later levels.
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = med_r - to_r(sub[f"R_lower_{_p}"].values)
            r_err = to_r(sub[f"R_upper_{_p}"].values) - med_r
            ax.errorbar(xs, med_r, yerr=[l_err, r_err], fmt='none',
                        color=color, elinewidth=3 + 2 * i)

    fig = plt.figure(figsize=(28, 14))

    ax_m = fig.add_subplot(1, 2, 1)
    _plot_panel(ax_m, r_m, "g_mean_", lambda x: (x, g_sd))
    ax_m.set_ylabel(f"Exponential growth rate ({date})")
    ax_m.set_xlabel("Mean of generation time")

    ax_s = fig.add_subplot(1, 2, 2, sharey=ax_m)
    _plot_panel(ax_s, r_s, "g_sd_", lambda x: (g_mean, x))
    ax_s.set_xlabel("Standard deviation of generation time")

    # Panel labels, offset up-left of each axes' top-left corner (in points / 72).
    label_offset = mtransforms.ScaledTranslation(-72/72, 20/72, fig.dpi_scale_trans)
    for label, ax in zip(["(a)", "(b)"], [ax_m, ax_s]):
        ax.text(0.0, 1.0, label, transform=ax.transAxes + label_offset,
                fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
# `lineage_colors` is never defined in this notebook; use the shared palette,
# ordered like the variants in the sensitivity table.
fig_6 = figure_little_r_sens(rt_sens_m, rt_sense_sd, ps, alphas,
                             [color_map[v] for v in pd.unique(rt_sens_m.variant)])

In [None]:
fig_6.savefig(
    "../manuscript/figs/little_r_sensitivity.png",
    facecolor="w",
    bbox_inches='tight',
)

## Figure: Sensitivity of the growth advantage to the generation time

In [None]:
# Growth-advantage estimates re-run with a varied generation-time mean and a varied generation-time sd.
ga_sens_m = pd.read_csv(
    "../estimates/variants-us-sensitivity-means/variants-us-sensitivity-means_ga-combined-fixed.tsv",
    sep="\t",
)
ga_sense_sd = pd.read_csv(
    "../estimates/variants-us-sensitivity-sd/variants-us-sensitivity-sd_ga-combined-fixed.tsv",
    sep="\t",
)

In [None]:
def figure_ga_sens(r_m, r_s, ps, alphas, colors, variant="Delta"):
    """Plot one variant's growth advantage against the generation-time settings.

    Panel (a) uses the runs in ``r_m`` (varied generation-time mean); panel (b)
    uses the runs in ``r_s`` (varied generation-time sd). Each run is stored as
    a pseudo-location named ``g_mean_<value>`` / ``g_sd_<value>``, and the
    numeric suffix becomes the x coordinate.

    Parameters
    ----------
    r_m, r_s : pandas.DataFrame
        Growth-advantage tables with columns ``location``, ``variant``,
        ``median_ga`` and ``ga_lower_{pct}`` / ``ga_upper_{pct}`` for each level
        in ``ps`` (``pct = int(p * 100)``).
    ps : sequence of float
        Credible levels; the ``i``-th interval is drawn with line width ``3 + 2 * i``.
    alphas
        Unused; kept for call compatibility with the other figure functions.
    colors : sequence
        One colour per variant, ordered like ``pd.unique(r_m.variant)``.
    variant : str, default "Delta"
        Variant to plot.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``variant`` does not occur in ``r_m``.
    """
    # Resolve the colour from this function's own input (previously this read
    # the variant order of a different, global R_t table).
    variants = list(pd.unique(r_m.variant))
    if variant not in variants:
        raise ValueError(f"variant {variant!r} not found in r_m")
    color = colors[variants.index(variant)]

    def _plot_panel(ax, df, prefix):
        # One row per sensitivity run for the chosen variant.
        sub = df[df.variant == variant]
        xs = [float(name.replace(prefix, "")) for name in sub.location.values]
        median = sub.median_ga.values
        ax.scatter(xs, median, edgecolors="k", s=160, zorder=3, color=color)
        # One error bar per credible level; thicker lines for later levels.
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = median - sub[f"ga_lower_{_p}"].values
            r_err = sub[f"ga_upper_{_p}"].values - median
            ax.errorbar(xs, median, yerr=[l_err, r_err], fmt='none',
                        color=color, elinewidth=3 + 2 * i)

    fig = plt.figure(figsize=(28, 14))

    ax_m = fig.add_subplot(1, 2, 1)
    _plot_panel(ax_m, r_m, "g_mean_")
    ax_m.set_ylabel("Growth advantage")
    ax_m.set_xlabel("Mean of generation time")

    ax_s = fig.add_subplot(1, 2, 2, sharey=ax_m)
    _plot_panel(ax_s, r_s, "g_sd_")
    ax_s.set_xlabel("Standard deviation of generation time")

    # Panel labels, offset left of each axes' top-left corner (in points / 72).
    label_offset = mtransforms.ScaledTranslation(-72/72, 0/72, fig.dpi_scale_trans)
    for label, ax in zip(["(a)", "(b)"], [ax_m, ax_s]):
        ax.text(0.0, 1.0, label, transform=ax.transAxes + label_offset,
                fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
# `lineage_colors` is never defined in this notebook; use the shared palette,
# ordered like the variants in the growth-advantage sensitivity table.
fig_7 = figure_ga_sens(ga_sens_m, ga_sense_sd, ps, alphas,
                       [color_map[v] for v in pd.unique(ga_sens_m.variant)])

In [None]:
fig_7.savefig(
    "../manuscript/figs/growth_advantage_sensitivity.png",
    facecolor="w",
    bbox_inches='tight',
)