In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from jax import random
plt.rcParams["figure.figsize"] = (10,8)

In [None]:
import rt_from_frequency_dynamics as rf

# Load data

In [None]:
data_name = "variants-us"
raw_cases = pd.read_csv(f"../data/{data_name}/{data_name}_location-case-counts.tsv", sep="\t")
raw_seq = pd.read_csv(f"../data/{data_name}/{data_name}_location-variant-sequence-counts.tsv", sep="\t")

# Load US States (SVI)

In [None]:
locations =  ["Washington", "California", "New York", "Michigan", "Florida"]
optimizer = numpyro.optim.Adam(step_size=1e-3)
num_samples = 3000

In [None]:
seed_L = 14
forecast_L = 0

# Get delays
v_names = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Mu', 'Omicron', 'other']

# Generation-time mean per variant: every variant shares 4.4 except Omicron (3.1).
# Driving the list from v_names keeps the distributions in the same order as the names.
gen_mean_by_variant = {name: 4.4 for name in v_names}
gen_mean_by_variant["Omicron"] = 3.1

gen = rf.pad_delays(
    [rf.discretise_gamma(mn=gen_mean_by_variant[name], std=1.2) for name in v_names]
)

delays = rf.pad_delays([rf.discretise_lognorm(mn=3.1, std=1.0)])

k = 25 # Number of spline basis elements

# Pick likelihoods
## R Likelihoods
GARW = rf.GARW(0.05, 0.05)
FreeGrowth = rf.FreeGrowth()
FGA = rf.FixedGA()

CLik = rf.ZINegBinomCases(0.02) # Case likelihood
SLik = rf.DirMultinomialSeq(100) # Sequence count likelihood

LM_fixed = rf.RenewalModel(gen, delays, seed_L, forecast_L, k=k, RLik = FGA, CLik = CLik, SLik = SLik,  v_names = v_names)
LM_GARW = rf.RenewalModel(gen, delays, seed_L, forecast_L, k=k, RLik = GARW, CLik = CLik, SLik = SLik,  v_names = v_names)

In [None]:
# Loading past results
def load_models(rc, rs, locations, RM, path=".", num_samples=1000):
    """Collect one posterior per location into a single ``rf.MultiPosterior``.

    For each entry of ``locations`` the location's data is built with
    ``rf.get_location_VariantData(rc, rs, loc)`` and handed, together with the
    model ``RM``, to ``rf.sample_loaded_posterior`` (called with ``path`` and
    ``name=loc``). A progress line is printed after every location.

    Args:
        rc: case-count table passed through to ``rf.get_location_VariantData``.
        rs: sequence-count table passed through to ``rf.get_location_VariantData``.
        locations: sized iterable of location names.
        RM: model object forwarded to ``rf.sample_loaded_posterior``.
        path: forwarded as ``path=``; presumably the folder holding the saved
            results - confirm against ``rf.sample_loaded_posterior``.
        num_samples: forwarded as ``num_samples=``.

    Returns:
        The populated ``rf.MultiPosterior``.
    """
    multi_posterior = rf.MultiPosterior()
    for done, loc in enumerate(locations, start=1):
        location_data = rf.get_location_VariantData(rc, rs, loc)
        multi_posterior.add_posterior(
            rf.sample_loaded_posterior(
                location_data, RM, num_samples=num_samples, path=path, name=loc
            )
        )
        print(f"Location {loc} finished {done} / {len(locations)}")
    return multi_posterior

In [None]:
path_base = f"../estimates/{data_name}"

path_free = path_base + "/free"
path_fixed = path_base + "/fixed"
path_GARW = path_base + "/GARW"

In [None]:
#MP_free = load_models(raw_cases, raw_seq, locations, rf.FreeGrowthModel, path=path_free, num_samples=1000)

In [None]:
MP_fixed = load_models(raw_cases, raw_seq, locations, LM_fixed, path=path_fixed, num_samples=1000)

In [None]:
MP_GARW = load_models(raw_cases, raw_seq, locations, LM_GARW, path=path_GARW, num_samples=1500)

## Making plots

In [None]:
import matplotlib
import matplotlib.transforms as mtransforms

font = {'family' : 'Helvetica',
        'weight' : 'light',
        'size'   : 32}

matplotlib.rc('font', **font)

In [None]:
from rt_from_frequency_dynamics.plotfunctions import *
ps = DefaultAes.ps
alphas = DefaultAes.alphas
v_colors =["#2e5eaa", "#5adbff",  "#56e39f","#b4c5e4", "#f03a47",  "#f5bb00", "#9e4244","#9932CC", "#808080"] 
v_names = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Mu', 'Omicron', 'other']
color_map = {v : c for c, v in zip(v_colors, v_names)}

In [None]:
def unpack_model(MP, loc):
    """Return ``(dataset, data)`` taken from the posterior that ``MP.get(loc)`` yields."""
    entry = MP.get(loc)
    dataset, data = entry.dataset, entry.data
    return dataset, data

## Plotting free Rt

In [None]:

def figure_free_rt(dataset, LD, ps, alphas, colors):
    """Build the multi-panel Rt figure for one location.

    The figure is 30x24 inches on a 4-row, 6-column gridspec:
      * row 0, left:   observed cases with the posterior smoothed cases
      * row 0, right:  observed cases with ``rf.plot_posterior_I`` in variant colors
      * row 1, left:   posterior average R
      * row 1, right:  ``rf.plot_R_censored`` with ``thres=0.001``
      * rows 2-3, left:   two stacked ``rf.plot_total_by_obs_frequency`` panels
      * rows 2-3, middle: posterior and observed variant frequencies
      * rows 2-3, right:  ``rf.plot_total_by_median_frequency`` applied to ``LD.cases``

    Args:
        dataset: posterior dataset handed to the ``rf.plot_posterior_*`` helpers.
        LD: location data object; this function reads ``LD.dates``, ``LD.cases``,
            ``LD.seq_counts`` and ``LD.seq_names``.
        ps: credibility levels forwarded to the plotting helpers.
        alphas: band transparencies forwarded to the plotting helpers.
        colors: one color per entry of ``LD.seq_names``.

    Returns:
        The matplotlib ``Figure``.
    """
    fig = plt.figure(figsize=(30, 24))
    gs = fig.add_gridspec(nrows=4, ncols=6, height_ratios=[2.,1.5, 1.0,1.0])
    single_color = "#3A3B3C"
    # Top left
    ax1 = fig.add_subplot(gs[0, :3])
    rf.plot_cases(ax1, LD)
    rf.plot_posterior_smooth_EC(ax1, dataset, ps, alphas, single_color)
    ax1.set_ylabel("Posterior smoothed cases") 
    
    # Top right (shares the y-axis with the top-left panel, so its tick labels are hidden)
    ax2 = fig.add_subplot(gs[0, 3:], sharey=ax1)
    rf.plot_cases(ax2, LD)
    rf.plot_posterior_I(ax2, dataset, ps, alphas, colors)
    plt.setp(ax2.get_yticklabels(), visible=False)

    # middle left
    ax3 = fig.add_subplot(gs[1,:3], sharex=ax1)
    rf.plot_posterior_average_R(ax3, dataset, ps, alphas, single_color)
    rf.add_dates(ax3, LD.dates)
    ax3.set_ylabel(r"$R_{t}$") 

    
    # middle right
    ax4 = fig.add_subplot(gs[1, 3:], sharex=ax2, sharey=ax3)
    rf.plot_R_censored(ax4, dataset, ps, alphas, colors, thres=0.001)
    rf.add_dates(ax4, LD.dates)
    plt.setp(ax4.get_yticklabels(), visible=False)

    #  Bottom left: first panel is scaled by the per-row sum of sequence counts
    ax5a = fig.add_subplot(gs[2, 0:2])
    rf.plot_total_by_obs_frequency(ax5a, LD, LD.seq_counts.sum(axis=1), colors)

    # Second panel is scaled by a vector of ones.
    # NOTE(review): `jnp` is not imported by name in this notebook; presumably it is
    # provided by the star import from rt_from_frequency_dynamics.plotfunctions - confirm.
    ax5b = fig.add_subplot(gs[3, 0:2], sharex=ax5a)
    rf.plot_total_by_obs_frequency(ax5b, LD, jnp.full(LD.cases.shape[-1], fill_value=1), colors)
    rf.add_dates(ax5b, LD.dates, sep=2)

    # Bottom middle
    ax6 = fig.add_subplot(gs[2:, 2:4])
    rf.plot_posterior_frequency(ax6, dataset, ps, alphas, colors)
    # The lambda maps a count n to 2.5*sqrt(n) for the observed-frequency markers
    rf.plot_observed_frequency_size(ax6, LD, colors, lambda n: 2.5*jnp.sqrt(n))
    rf.add_dates(ax6, LD.dates, sep=2)
    ax6.set_ylabel("Posterior variant frequencies")
    
    # Bottom right
    ax7 = fig.add_subplot(gs[2:, 4:6], sharey=ax1)
    rf.plot_total_by_median_frequency(ax7, dataset, LD, LD.cases, colors)
    rf.add_dates(ax7, LD.dates, sep=2)
    ax7.set_ylabel("Median variant cases") 
 
    # Add panel labels (only the three panels listed here are lettered)
    axs = [ax1, ax2, ax5a]
    labels = ["(a)", "(b)", "(c)"]
    
    for label, ax in zip(labels, axs):
        # Offset the label from the axes' top-left corner by a fixed amount in inches
        trans = mtransforms.ScaledTranslation(-42/72, 14/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom', fontfamily='serif')
            
    # Putting down color legend
    patches = [matplotlib.patches.Patch(color=c, label=l) for l, c in zip(LD.seq_names, colors)]
    legend = fig.legend(patches, LD.seq_names, ncol=len(LD.seq_names), loc="lower center")  
    legend.get_frame().set_linewidth(2.)
    legend.get_frame().set_edgecolor("k")
    fig.tight_layout()
    # Leave room at the bottom of the figure for the legend
    fig.subplots_adjust(bottom = 0.1) 
    return fig

In [None]:
dataset, LD = unpack_model(MP_GARW, "Washington")
colors = [color_map[v] for v in LD.seq_names]

In [None]:
fig_1 = figure_free_rt(dataset, LD, ps, alphas, colors)

In [None]:
fig_free_rt_locs = ["Washington", "California", "New York", "Michigan", "Florida"]

In [None]:
for loc in fig_free_rt_locs:
    dataset, LD = unpack_model(MP_GARW, loc)
    colors = [color_map[v] for v in LD.seq_names]
    fig_fg_loc = figure_free_rt(dataset, LD, ps, alphas, colors)
    _loc = loc.replace(" ", "-")
    fig_fg_loc.savefig(f"../manuscript/figs/GARW_rt_{_loc}.png", facecolor="w", bbox_inches='tight')

## Plotting growth advantage

In [None]:
def figure_fixed_growth(dataset, LD, ps, alphas, colors):
    """Build the five-panel growth-advantage figure for one location.

    The figure is 30x20 inches on a 4-row, 2-column gridspec:
      * (a) rows 0-1, left:  observed cases with the posterior smoothed cases
      * (b) rows 0-1, right: observed cases with ``plot_posterior_I`` in variant colors
      * (c) rows 2-3, left:  posterior and observed lineage frequencies
      * (d) row 2, right:    ``plot_R_censored`` with ``thres=0.005``
      * (e) row 3, right:    ``plot_growth_advantage``

    The plotting helpers are called unqualified.
    NOTE(review): presumably they come from the star import of
    rt_from_frequency_dynamics.plotfunctions, as does `jnp` - confirm.

    Args:
        dataset: posterior dataset handed to the plotting helpers.
        LD: location data object; this function reads ``LD.dates`` and ``LD.seq_names``.
        ps: credibility levels forwarded to the plotting helpers.
        alphas: band transparencies forwarded to the plotting helpers.
        colors: one color per entry of ``LD.seq_names``.

    Returns:
        The matplotlib ``Figure``.
    """
    # Figure 2
    fig = plt.figure(figsize=(30, 20))
    gs = fig.add_gridspec(nrows=4, ncols=2) #, height_ratios=[2.,1.5, 1.0,1.0])
    single_color = "#3A3B3C"

    # Top left
    ax1 = fig.add_subplot(gs[:2,0])
    plot_cases(ax1, LD)
    plot_posterior_smooth_EC(ax1, dataset, ps, alphas, single_color)
    ax1.set_ylabel("Posterior smoothed cases")
    
    # Top right (shares the y-axis with the top-left panel, so its tick labels are hidden)
    ax2 = fig.add_subplot(gs[:2,1], sharey=ax1)
    plot_cases(ax2, LD)
    plot_posterior_I(ax2, dataset, ps, alphas, colors)
    plt.setp(ax2.get_yticklabels(), visible=False)

    # Bottom left
    ax3 = fig.add_subplot(gs[2:,0], sharex=ax1)
    plot_posterior_frequency(ax3, dataset, ps, alphas, colors)
    # The lambda maps a count n to 2.5*sqrt(n) for the observed-frequency markers
    plot_observed_frequency_size(ax3, LD, colors, lambda n: 2.5*jnp.sqrt(n))
    add_dates(ax3, LD.dates)
    ax3.set_ylabel("Posterior lineage frequencies")

    # Bottom right 1
    ax4 = fig.add_subplot(gs[2,1], sharex=ax2)
    plot_R_censored(ax4, dataset, ps, alphas, colors, thres=0.005)
    add_dates(ax4, LD.dates)
    ax4.set_ylabel(r"$R_{t}$")

    # Bottom right 2
    ax5 = fig.add_subplot(gs[3,1])
    plot_growth_advantage(ax5, dataset, LD, ps, alphas, colors)
    ax5.set_ylabel("Growth Advantage")

    axs = [ax1, ax2, ax3, ax4, ax5]
    labels = ["(a)", "(b)", "(c)", "(d)", "(e)"]
    
    for label, ax in zip(labels, axs):
        # Offset the label from the axes' top-left corner by a fixed amount in inches
        trans = mtransforms.ScaledTranslation(-32/72, 8/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom', fontfamily='serif')
    # Layout pass before the legend is added; a second pass follows below
    plt.tight_layout()
    
    # Putting down color legend
    patches = [matplotlib.patches.Patch(color=c, label=l) for l, c in zip(LD.seq_names, colors)]
    legend = fig.legend(patches, LD.seq_names, ncol=len(LD.seq_names), loc="lower center")  
    legend.get_frame().set_linewidth(2.)
    legend.get_frame().set_edgecolor("k")
    fig.tight_layout()
    # Leave room at the bottom of the figure for the legend
    fig.subplots_adjust(bottom = 0.1) 
    return fig

In [None]:
dataset, LD = unpack_model(MP_fixed, "Washington")
colors = [color_map[v] for v in LD.seq_names]
fig_2 = figure_fixed_growth(dataset, LD, ps, alphas, colors)

In [None]:
fig_fixed_growth_locs = ["Washington", "California", "New York", "Michigan", "Florida"]

In [None]:
for loc in fig_fixed_growth_locs:
    dataset, LD = unpack_model(MP_fixed, loc)
    colors = [color_map[v] for v in LD.seq_names]
    fig_fg_loc = figure_fixed_growth(dataset, LD, ps, alphas, colors)
    # Use the same space-free naming as the GARW figures saved earlier in this
    # notebook, so "New York" becomes "New-York" in the file name.
    _loc = loc.replace(" ", "-")
    fig_fg_loc.savefig(f"../manuscript/figs/fixed_growth_{_loc}.png", facecolor="w", bbox_inches='tight')

## Figure: Growth advantages

In [None]:
ga_df = pd.read_csv(f"{path_base}/{data_name}_ga-combined-fixed.tsv", sep = "\t")

In [None]:
def figure_growth_advantage(ga_df, ps, alphas, colors, jitter_seed=0):
    """Plot growth advantages per variant (top) and per location (bottom).

    Panel (a), top: one violin per variant over the ``median_ga`` values of all
    locations, with the individual values overlaid as horizontally jittered points.
    Panel (b), bottom: ``median_ga`` per location and variant, with one error
    bar per credibility level in ``ps``.

    Args:
        ga_df: table with columns ``variant``, ``location``, ``median_ga`` and,
            for every ``p`` in ``ps``, ``ga_lower_{int(p*100)}`` and
            ``ga_upper_{int(p*100)}``.
        ps: credibility levels; each one selects a pair of bound columns.
        alphas: not used by this figure; kept so the signature matches the
            other figure functions.
        colors: one color per variant, in the order of ``pd.unique(ga_df.variant)``.
        jitter_seed: seed for the horizontal jitter in panel (a), so that
            re-running the notebook reproduces the same figure.

    Returns:
        The matplotlib ``Figure``.
    """
    fig = plt.figure(figsize=(28, 20))

    variants = pd.unique(ga_df.variant)
    locations = pd.unique(ga_df.location)
    location_map = {l: i for i, l in enumerate(locations)}

    # One error-bar width per credibility level (1.5, 2.5, 3.5, ...); derived
    # from len(ps) so that more than three levels do not raise IndexError.
    _lw = [1.5 + i for i in range(len(ps))]

    # Dedicated generator: seeded, and it leaves numpy's global random state alone.
    rng = np.random.default_rng(jitter_seed)

    # Bottom panel: growth advantage by location
    ax1 = fig.add_subplot(2, 1, 2)
    ax1.axhline(y=1, lw=2, linestyle='dashed', color="k")

    for v, var in enumerate(variants):
        this_lineage = ga_df[ga_df.variant == var]
        location_num = this_lineage["location"].map(location_map)
        medians = this_lineage.median_ga.values
        ax1.scatter(location_num, medians,
                    color=colors[v],
                    edgecolors="k",
                    s=45,
                    zorder=3)

        # Plot error bars for each level of credibility
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = medians - this_lineage[f"ga_lower_{_p}"].values
            r_err = this_lineage[f"ga_upper_{_p}"].values - medians
            ax1.errorbar(location_num, medians,
                         yerr=[l_err, r_err],
                         fmt='none',
                         color=colors[v], elinewidth=_lw[i])

    # Adding state labels
    ax1.set_xticks(np.arange(0, len(locations), 1))
    ax1.set_xticklabels([l.replace("_", " ") for l in locations], rotation=90)

    # Adding axis label
    ax1.set_ylabel("Growth Advantage")

    # Top panel: distribution of median growth advantage by variant
    ax2 = fig.add_subplot(2, 1, 1)
    ax2.axhline(y=1, lw=2, linestyle='dashed', color="k")

    violin_data = [ga_df[ga_df.variant == v].median_ga.values for v in variants]
    parts = ax2.violinplot(violin_data,
                           showmeans=False,
                           showmedians=False,
                           showextrema=False)

    for i, pc in enumerate(parts["bodies"]):
        pc.set_facecolor(colors[i])
        pc.set_edgecolor('black')
        pc.set_alpha(1)

    for v, var in enumerate(variants):
        this_lineage = ga_df[ga_df.variant == var]
        # Violins sit at x = 1, 2, ...; spread the points slightly around that position.
        jittered_x = v + 1 + rng.normal(0, 0.02, size=len(this_lineage))
        ax2.scatter(jittered_x,
                    this_lineage.median_ga.values,
                    color=colors[v],
                    edgecolors="k",
                    s=45,
                    zorder=3)

    ax2.set_ylabel("Median Growth Advantage")
    ax2.set_xticks(np.arange(1, len(variants) + 1, 1))
    ax2.set_xticklabels(variants)

    axs = [ax2, ax1]
    labels = ["(a)", "(b)"]

    for label, ax in zip(labels, axs):
        # Offset the label from the axes' top-left corner by a fixed amount in inches
        trans = mtransforms.ScaledTranslation(-32/72, 8/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
colors = [color_map[v] for v in pd.unique(ga_df.variant)]

In [None]:
fig_3 = figure_growth_advantage(ga_df, ps, alphas, colors)

In [None]:
 fig_3.savefig("../manuscript/figs/growth_advantages.png", facecolor="w", bbox_inches='tight')

## Figure: Rt consensus

In [None]:
rt_df_GARW = pd.read_csv(f"{path_base}/{data_name}_Rt-combined-GARW.tsv", sep="\t")

In [None]:
def figure_rt_consensus(rt_df, LD, ps, alphas, colors, thres = 0.001):
    """Draw one panel per variant showing R over time for every location.

    Each panel overlays, for every location, the median R line and one shaded
    band per credibility level. Time points whose ``median_freq`` is below
    ``thres`` are left out.

    Args:
        rt_df: table with columns ``variant``, ``location``, ``date``,
            ``median_freq``, ``median_R`` and, for every ``p`` in ``ps``,
            ``R_lower_{int(p*100)}`` and ``R_upper_{int(p*100)}``.
        LD: location data object; only ``LD.dates`` is read, for the x tick labels.
            NOTE(review): the x positions are indices into the sorted unique
            dates of ``rt_df``, so this assumes ``LD.dates`` covers the same
            dates - confirm.
        ps: credibility levels.
        alphas: band transparency per credibility level (same length as ``ps``).
        colors: one color per variant, in the order of ``pd.unique(rt_df.variant)``.
        thres: minimum median frequency for a time point to be drawn.

    Returns:
        The matplotlib ``Figure``.
    """
    fig = plt.figure(figsize=(25, 15))
    variants = pd.unique(rt_df.variant)
    locations = pd.unique(rt_df.location)
    # Position of each date on the x-axis (sorted() works for any array type
    # that pd.unique may return).
    dates_map = {d: i for i, d in enumerate(sorted(pd.unique(rt_df.date)))}

    n_rows = 3
    n_cols = -(-len(variants) // n_rows)  # ceiling division

    ax_list = []

    for v, var in enumerate(variants):
        # Every panel after the first shares its y-axis with the first one.
        share = {} if v == 0 else {"sharey": ax_list[0]}
        ax = fig.add_subplot(n_rows, n_cols, v + 1, **share)

        ax.axhline(y=1, lw=2, linestyle='dashed', color="k")
        this_variant = rt_df[rt_df.variant == var]

        for loc in locations:
            this_loc = this_variant[this_variant.location == loc]
            included = np.array(this_loc.median_freq.values >= thres)
            dates_num = this_loc["date"].map(dates_map)

            m = this_loc["median_R"].values
            ax.plot(dates_num[included], m[included], color = 'k', alpha = 0.1)

            # Plot bands for each level of credibility
            for i, p in enumerate(ps):
                _p = int(p * 100)
                lower = this_loc[f"R_lower_{_p}"].values
                upper = this_loc[f"R_upper_{_p}"].values
                ax.fill_between(dates_num[included], lower[included], upper[included],
                                color = colors[v], alpha=alphas[i])

        # Add dates
        add_dates(ax, LD.dates, sep=2)

        # Only the first column keeps its y tick labels and axis label.
        if v % n_cols != 0:
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            ax.set_ylabel("Effective Reproduction Number")

        ax.set_title(var)
        ax.set_ylim(0.0, 5.0)
        # Add to list
        ax_list.append(ax)

    # A single layout pass once all panels exist, instead of one per panel.
    fig.tight_layout()

    return fig

In [None]:
# Colors follow the variant order of the table that is plotted (rt_df_GARW);
# the name `rt_df_free` used previously is never defined in this notebook.
colors = [color_map[v] for v in pd.unique(rt_df_GARW.variant)]
fig_4 = figure_rt_consensus(rt_df_GARW, LD, ps, alphas, colors, thres = 0.01)

In [None]:
fig_4.savefig("../manuscript/figs/rt_consensus.png", facecolor="w", bbox_inches='tight')

## Figure: generation_time_sensitivity

In [None]:
rt_sens_m = pd.read_csv("../estimates/variants-us-sensitivity-means/variants-us-sensitivity-means_Rt-combined-free.tsv", sep = "\t")
rt_sense_sd = pd.read_csv("../estimates/variants-us-sensitivity-sd/variants-us-sensitivity-sd_Rt-combined-free.tsv", sep="\t")

In [None]:
def figure_gen_sens(r_m, r_s, ps, alphas, colors):
    """Plot how the Delta R estimate on 2021-07-01 varies with the generation time.

    Panel (a) uses ``r_m`` and plots R against the generation-time mean; panel
    (b) uses ``r_s`` and plots R against the generation-time standard
    deviation. Both panels share the y-axis and show the median with one error
    bar per credibility level in ``ps``.

    Args:
        r_m: table whose ``location`` column holds labels of the form
            ``g_mean_<value>``; also needs ``variant``, ``date``, ``median_R``
            and ``R_lower_{int(p*100)}`` / ``R_upper_{int(p*100)}`` for every
            ``p`` in ``ps``.
        r_s: same layout, with ``location`` labels of the form ``g_sd_<value>``.
        ps: credibility levels.
        alphas: not used by this figure; kept for signature parity.
        colors: one color per variant, in the order of ``pd.unique(r_m.variant)``.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: if ``r_m`` contains no rows for the plotted variant.
    """
    variant = "Delta"
    date = "2021-07-01"

    # One error-bar width per credibility level (3, 5, 7, ...).
    _lw = [3 + 2 * i for i in range(len(ps))]

    # The numeric generation-time value is encoded in the "location" label.
    ms_map = {n: float(n.replace("g_mean_", "")) for n in pd.unique(r_m.location)}
    sds_map = {n: float(n.replace("g_sd_", "")) for n in pd.unique(r_s.location)}

    # Pick the plot color from the table passed in, not from a notebook global.
    variants = list(pd.unique(r_m.variant))
    if variant not in variants:
        raise ValueError(f"variant {variant!r} not found in r_m")
    color = colors[variants.index(variant)]

    # Restrict both tables to the chosen variant and date
    this_variant_m = r_m[(r_m.variant == variant) & (r_m.date == date)]
    this_variant_s = r_s[(r_s.variant == variant) & (r_s.date == date)]

    def _panel(ax, x_values, df, xlabel):
        # Median points plus one asymmetric error bar per credibility level.
        medians = df.median_R.values
        ax.scatter(x_values, medians, edgecolors="k", s=160, zorder=3, color=color)
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = medians - df[f"R_lower_{_p}"].values
            r_err = df[f"R_upper_{_p}"].values - medians
            ax.errorbar(x_values, medians, yerr=[l_err, r_err],
                        fmt='none', color=color, elinewidth=_lw[i])
        ax.set_xlabel(xlabel)

    # Making figure
    fig = plt.figure(figsize=(28, 14))

    ax_m = fig.add_subplot(1, 2, 1)
    _panel(ax_m, [ms_map[m] for m in this_variant_m.location.values],
           this_variant_m, "Mean of generation time")
    ax_m.set_ylabel(f"Effective Reproduction Number ({date})")

    ax_s = fig.add_subplot(1, 2, 2, sharey=ax_m)
    _panel(ax_s, [sds_map[s] for s in this_variant_s.location.values],
           this_variant_s, "Standard deviation of generation time")

    for label, ax in zip(["(a)", "(b)"], [ax_m, ax_s]):
        # Offset the label from the axes' top-left corner by a fixed amount in inches
        trans = mtransforms.ScaledTranslation(-72/72, 0/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
# `lineage_colors` was never defined in this notebook; build it as one color per
# variant, in the order the variants first appear in the table being plotted.
lineage_colors = [color_map[v] for v in pd.unique(rt_sens_m.variant)]
fig_5 = figure_gen_sens(rt_sens_m, rt_sense_sd, ps, alphas, lineage_colors)

In [None]:
fig_5.savefig("../manuscript/figs/generation_time_sensitivity.png", facecolor="w", bbox_inches='tight')

## Figure: little_r_sensitivity

In [None]:
def R_gamma_to_sens(R, m, s):
    """Convert a reproduction number into an exponential growth rate.

    The generation time is discretised with ``rf.discretise_gamma(mn=m, std=s)``
    and its mean ``mn`` and standard deviation ``sd`` are recomputed from the
    resulting weights, taking the k-th weight (1-based) as the probability of
    generation time k. The value returned is

        (R ** (sd**2 / mn**2) - 1) * mn / sd**2

    which is the inverse of ``R = (1 + r / rate) ** shape`` for
    ``shape = mn**2 / sd**2`` and ``rate = mn / sd**2``.

    Args:
        R: reproduction number.
        m: mean passed to the discretisation.
        s: standard deviation passed to the discretisation.

    Returns:
        The growth rate ``r`` as a float.
    """
    # Qualified call with keywords, matching how the generation times are built
    # elsewhere in this notebook, so it does not depend on a star import.
    g = rf.discretise_gamma(mn=m, std=s)
    mn = np.sum([p * (x+1) for x, p in enumerate(g)]) # Get mean of discretized generation time
    sd = np.sqrt(np.sum([p * (x+1) **2 for x, p in enumerate(g)])-mn**2) # Get sd of discretized generation time
    e_ = sd**2 / mn**2
    l = mn / (sd**2)
    return (np.float_power(R, e_) - 1) * l

In [None]:
def figure_little_r_sens(r_m, r_s, ps, alphas, colors, base_mean=5.2, base_sd=1.72):
    """Plot how the Delta growth rate on 2021-07-01 varies with the generation time.

    R estimates are converted to exponential growth rates with
    ``R_gamma_to_sens``. Panel (a) varies the generation-time mean (taken from
    ``r_m``) while holding the standard deviation at ``base_sd``; panel (b)
    varies the standard deviation (taken from ``r_s``) while holding the mean
    at ``base_mean``. Both panels share the y-axis.

    Args:
        r_m: table whose ``location`` column holds labels of the form
            ``g_mean_<value>``; also needs ``variant``, ``date``, ``median_R``
            and ``R_lower_{int(p*100)}`` / ``R_upper_{int(p*100)}`` for every
            ``p`` in ``ps``.
        r_s: same layout, with ``location`` labels of the form ``g_sd_<value>``.
        ps: credibility levels.
        alphas: not used by this figure; kept for signature parity.
        colors: one color per variant, in the order of ``pd.unique(r_m.variant)``.
        base_mean: generation-time mean used for every point of panel (b).
        base_sd: generation-time standard deviation used for every point of panel (a).

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: if ``r_m`` contains no rows for the plotted variant.
    """
    variant = "Delta"
    date = "2021-07-01"

    # One error-bar width per credibility level (3, 5, 7, ...).
    _lw = [3 + 2 * i for i in range(len(ps))]

    # The numeric generation-time value is encoded in the "location" label.
    ms_map = {n: float(n.replace("g_mean_", "")) for n in pd.unique(r_m.location)}
    sds_map = {n: float(n.replace("g_sd_", "")) for n in pd.unique(r_s.location)}

    # Pick the plot color from the table passed in, not from a notebook global.
    variants = list(pd.unique(r_m.variant))
    if variant not in variants:
        raise ValueError(f"variant {variant!r} not found in r_m")
    color = colors[variants.index(variant)]

    # Restrict both tables to the chosen variant and date
    this_variant_m = r_m[(r_m.variant == variant) & (r_m.date == date)]
    this_variant_s = r_s[(r_s.variant == variant) & (r_s.date == date)]

    def _growth_rates(R_values, means, sds):
        # Element-wise conversion of R to a growth rate.
        return np.array([R_gamma_to_sens(R, m, s) for R, m, s in zip(R_values, means, sds)])

    def _panel(ax, x_values, df, means, sds, xlabel):
        # Median points plus one asymmetric error bar per credibility level,
        # all expressed as growth rates.
        med_r = _growth_rates(df.median_R.values, means, sds)
        ax.scatter(x_values, med_r, edgecolors="k", s=160, zorder=3, color=color)
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = med_r - _growth_rates(df[f"R_lower_{_p}"].values, means, sds)
            r_err = _growth_rates(df[f"R_upper_{_p}"].values, means, sds) - med_r
            ax.errorbar(x_values, med_r, yerr=[l_err, r_err],
                        fmt='none', color=color, elinewidth=_lw[i])
        ax.set_xlabel(xlabel)

    # Making figure
    fig = plt.figure(figsize=(28, 14))

    # Panel (a): varying mean, fixed standard deviation
    ms_values = [ms_map[m] for m in this_variant_m.location.values]
    ax_m = fig.add_subplot(1, 2, 1)
    _panel(ax_m, ms_values, this_variant_m,
           ms_values, [base_sd] * len(ms_values),
           "Mean of generation time")
    ax_m.set_ylabel(f"Exponential growth rate ({date})")

    # Panel (b): fixed mean, varying standard deviation
    sd_values = [sds_map[s] for s in this_variant_s.location.values]
    ax_s = fig.add_subplot(1, 2, 2, sharey=ax_m)
    _panel(ax_s, sd_values, this_variant_s,
           [base_mean] * len(sd_values), sd_values,
           "Standard deviation of generation time")

    for label, ax in zip(["(a)", "(b)"], [ax_m, ax_s]):
        # Offset the label from the axes' top-left corner by a fixed amount in inches
        trans = mtransforms.ScaledTranslation(-72/72, 20/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
# `lineage_colors` was never defined in this notebook; build it as one color per
# variant, in the order the variants first appear in the table being plotted.
lineage_colors = [color_map[v] for v in pd.unique(rt_sens_m.variant)]
fig_6 = figure_little_r_sens(rt_sens_m, rt_sense_sd, ps, alphas, lineage_colors)

In [None]:
fig_6.savefig("../manuscript/figs/little_r_sensitivity.png", facecolor="w", bbox_inches='tight')

## Figure: growth_advantage_sensitivity

In [None]:
ga_sens_m = pd.read_csv("../estimates/variants-us-sensitivity-means/variants-us-sensitivity-means_ga-combined-fixed.tsv", sep = "\t")
ga_sense_sd = pd.read_csv("../estimates/variants-us-sensitivity-sd/variants-us-sensitivity-sd_ga-combined-fixed.tsv", sep="\t")

In [None]:
def figure_ga_sens(r_m, r_s, ps, alphas, colors):
    """Plot how the Delta growth advantage varies with the generation time.

    Panel (a) uses ``r_m`` and plots the growth advantage against the
    generation-time mean; panel (b) uses ``r_s`` and plots it against the
    generation-time standard deviation. Both panels share the y-axis and show
    the median with one error bar per credibility level in ``ps``.

    Args:
        r_m: table whose ``location`` column holds labels of the form
            ``g_mean_<value>``; also needs ``variant``, ``median_ga`` and
            ``ga_lower_{int(p*100)}`` / ``ga_upper_{int(p*100)}`` for every
            ``p`` in ``ps``.
        r_s: same layout, with ``location`` labels of the form ``g_sd_<value>``.
        ps: credibility levels.
        alphas: not used by this figure; kept for signature parity.
        colors: one color per variant, in the order of ``pd.unique(r_m.variant)``.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: if ``r_m`` contains no rows for the plotted variant.
    """
    variant = "Delta"

    # One error-bar width per credibility level (3, 5, 7, ...).
    _lw = [3 + 2 * i for i in range(len(ps))]

    # The numeric generation-time value is encoded in the "location" label.
    ms_map = {n: float(n.replace("g_mean_", "")) for n in pd.unique(r_m.location)}
    sds_map = {n: float(n.replace("g_sd_", "")) for n in pd.unique(r_s.location)}

    # Pick the plot color from the growth-advantage table passed in, rather than
    # from the unrelated Rt table held in a notebook global.
    variants = list(pd.unique(r_m.variant))
    if variant not in variants:
        raise ValueError(f"variant {variant!r} not found in r_m")
    color = colors[variants.index(variant)]

    # Restrict both tables to the chosen variant (no date filter is applied here)
    this_variant_m = r_m[r_m.variant == variant]
    this_variant_s = r_s[r_s.variant == variant]

    def _panel(ax, x_values, df, xlabel):
        # Median points plus one asymmetric error bar per credibility level.
        medians = df.median_ga.values
        ax.scatter(x_values, medians, edgecolors="k", s=160, zorder=3, color=color)
        for i, p in enumerate(ps):
            _p = int(p * 100)
            l_err = medians - df[f"ga_lower_{_p}"].values
            r_err = df[f"ga_upper_{_p}"].values - medians
            ax.errorbar(x_values, medians, yerr=[l_err, r_err],
                        fmt='none', color=color, elinewidth=_lw[i])
        ax.set_xlabel(xlabel)

    # Making figure
    fig = plt.figure(figsize=(28, 14))

    ax_m = fig.add_subplot(1, 2, 1)
    _panel(ax_m, [ms_map[m] for m in this_variant_m.location.values],
           this_variant_m, "Mean of generation time")
    ax_m.set_ylabel("Growth advantage")

    ax_s = fig.add_subplot(1, 2, 2, sharey=ax_m)
    _panel(ax_s, [sds_map[s] for s in this_variant_s.location.values],
           this_variant_s, "Standard deviation of generation time")

    for label, ax in zip(["(a)", "(b)"], [ax_m, ax_s]):
        # Offset the label from the axes' top-left corner by a fixed amount in inches
        trans = mtransforms.ScaledTranslation(-72/72, 0/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom', fontfamily='serif')

    return fig

In [None]:
# `lineage_colors` was never defined in this notebook; build it as one color per
# variant, in the order the variants first appear in the growth-advantage table.
lineage_colors = [color_map[v] for v in pd.unique(ga_sens_m.variant)]
fig_7 = figure_ga_sens(ga_sens_m, ga_sense_sd, ps, alphas, lineage_colors)

In [None]:
fig_7.savefig("../manuscript/figs/growth_advantage_sensitivity.png", facecolor="w", bbox_inches='tight')

## Fig: Relative Wave Sizes

In [None]:
# Computing waves sizes

I_GARW = pd.read_csv(f"{path_base}/{data_name}_I-combined-GARW.tsv", sep="\t")

In [None]:
# NOTE(review): this path points outside the repository (a local Downloads
# folder), so the cell only runs on a machine that has the file there.
# The file is read without a header row; column 0 is then named "location"
# and column 1 "pop_size".
state_pop = pd.read_csv("../../../Downloads/state-population-sizes.tsv", sep="\t", header=None)
state_pop = state_pop.rename(columns={0: "location", 1: "pop_size"})

In [None]:
summed_posterior_I = (
    I_GARW
    .groupby(["location", "variant"])["median_I"] 
    .sum()
    .reset_index()
)

summed_posterior_I = pd.merge(summed_posterior_I, state_pop, on="location")

In [None]:
summed_posterior_I["frac_I"] = summed_posterior_I["median_I"] / summed_posterior_I["pop_size"]

In [None]:
abv_to_state = {
        'AK': 'Alaska',
        'AL': 'Alabama',
        'AR': 'Arkansas',
        'AS': 'American Samoa',
        'AZ': 'Arizona',
        'CA': 'California',
        'CO': 'Colorado',
        'CT': 'Connecticut',
        'DC': 'District of Columbia',
        'DE': 'Delaware',
        'FL': 'Florida',
        'GA': 'Georgia',
        'GU': 'Guam',
        'HI': 'Hawaii',
        'IA': 'Iowa',
        'ID': 'Idaho',
        'IL': 'Illinois',
        'IN': 'Indiana',
        'KS': 'Kansas',
        'KY': 'Kentucky',
        'LA': 'Louisiana',
        'MA': 'Massachusetts',
        'MD': 'Maryland',
        'ME': 'Maine',
        'MI': 'Michigan',
        'MN': 'Minnesota',
        'MO': 'Missouri',
        'MP': 'Northern Mariana Islands',
        'MS': 'Mississippi',
        'MT': 'Montana',
        'NA': 'National',
        'NC': 'North Carolina',
        'ND': 'North Dakota',
        'NE': 'Nebraska',
        'NH': 'New Hampshire',
        'NJ': 'New Jersey',
        'NM': 'New Mexico',
        'NV': 'Nevada',
        'NY': 'New York',
        'OH': 'Ohio',
        'OK': 'Oklahoma',
        'OR': 'Oregon',
        'PA': 'Pennsylvania',
        'PR': 'Puerto Rico',
        'RI': 'Rhode Island',
        'SC': 'South Carolina',
        'SD': 'South Dakota',
        'TN': 'Tennessee',
        'TX': 'Texas',
        'UT': 'Utah',
        'VA': 'Virginia',
        'VI': 'Virgin Islands',
        'VT': 'Vermont',
        'WA': 'Washington',
        'WI': 'Wisconsin',
        'WV': 'West Virginia',
        'WY': 'Wyoming'
}

state_to_abv = {v: k for k, v in abv_to_state.items()}

In [None]:
def relative_wave_sizes(I_df, colors, callouts):
    """Compare the relative wave sizes of Alpha, Delta and Omicron across locations.

    Top panel: one line per location through its three wave sizes.
    Bottom left: Alpha vs Delta wave size, one point per location.
    Bottom right: Delta vs Omicron wave size, one point per location.
    Locations listed in ``callouts`` are annotated in the bottom panels with
    their abbreviation from the notebook-level ``state_to_abv`` mapping (the
    full name is used when no abbreviation is known).

    Values are matched by location and variant label rather than by row
    position, so a location that lacks one of the variants shows a gap instead
    of shifting or breaking the plot.

    Args:
        I_df: table with columns ``location``, ``variant`` and ``frac_I``.
        colors: three colors, for Alpha, Delta and Omicron in that order.
        callouts: location names to annotate; names absent from ``I_df`` are skipped.

    Returns:
        None.
    """
    fig = plt.figure(figsize=(18, 18))
    gs = fig.add_gridspec(nrows=2, ncols=2, height_ratios=[0.8, 1.0])

    variants = ["Alpha", "Delta", "Omicron"]

    # One row per location and one column per variant (NaN where a pair is
    # missing), so x and y values always refer to the same location.
    wave_sizes = (
        I_df[I_df["variant"].isin(variants)]
        .drop_duplicates(subset=["location", "variant"])
        .pivot(index="location", columns="variant", values="frac_I")
        .reindex(columns=variants)
    )

    # Comparing relative wave sizes across variants
    ax = fig.add_subplot(gs[0, 0:2])
    t = np.arange(len(variants))
    for sizes in wave_sizes.to_numpy():
        ax.plot(t, sizes, color="k", alpha=0.2)
        ax.scatter(t, sizes,
                   color=colors,
                   edgecolors="k",
                   s=60,
                   zorder=3)
    ax.set_xticks(np.arange(0, len(variants), 1))
    ax.set_xticklabels(variants)
    ax.set_ylabel("Relative wave size")

    def _pairwise_panel(ax, x_variant, y_variant, color):
        # Scatter one variant's wave size against another's and label the callouts.
        ax.scatter(wave_sizes[x_variant], wave_sizes[y_variant],
                   edgecolors="k",
                   color=color,
                   s=45,
                   zorder=3)
        ax.set_xlabel(f"{x_variant} wave size")
        ax.set_ylabel(f"{y_variant} wave size")

        for loc in callouts:
            if loc not in wave_sizes.index:
                continue
            x = wave_sizes.at[loc, x_variant]
            y = wave_sizes.at[loc, y_variant]
            if pd.isna(x) or pd.isna(y):
                continue
            # Scalar coordinates: annotate needs a single (x, y) point.
            ax.annotate(state_to_abv.get(loc, loc), (x, y), size=20)

    # Comparing first two
    _pairwise_panel(fig.add_subplot(gs[1, 0]), variants[0], variants[1], "orange")

    # Comparing last two
    _pairwise_panel(fig.add_subplot(gs[1, 1]), variants[-2], variants[-1], "pink")

    fig.tight_layout()


In [None]:
VoI = ["Alpha", "Delta", "Omicron"]
colors = [color_map[v] for v in VoI]
callouts = ["Washington", "California", "New York", "Michigan", "Florida"]

In [None]:
relative_wave_sizes(summed_posterior_I, colors, callouts=callouts)
#TODO: Couple of call-outs to orient folks
#TODO: Regression line and R^2?

# Fig: Correlation of initial R and wave size?

- Pull out the earliest day on which the variant frequency is greater than the threshold
- Get this R value
- Plot against wave size for that variant
- Columns: location, variant, R..., frequency, wave_size

In [None]:
thres = 0.05
R_GARW = pd.read_csv(f"{path_base}/{data_name}_Rt-combined-GARW.tsv", sep="\t")
R_GARW["date"] = pd.to_datetime(R_GARW["date"])

In [None]:
# Keep the rows whose median frequency exceeds `thres`, then, for every
# (location, variant) pair, keep only the row with the earliest date.
R_thres= R_GARW[R_GARW.median_freq > thres]
R_thres=R_thres.loc[R_thres.groupby(["location", "variant"]).date.idxmin()]

# Attach the summed wave size of the same (location, variant) pair; with
# pd.merge's default inner join, pairs missing from either table are dropped.
R_wave_size = pd.merge(R_thres, summed_posterior_I, on=["location", "variant"])

In [None]:
R_wave_size_plot = R_wave_size[~R_wave_size.variant.isin(["other", "Omicron"])]

In [None]:
 pd.unique(R_wave_size["variant"])

In [None]:
def initial_R_wave_size(df,colors):
    """Scatter ``median_R`` against ``frac_I``, one color per variant.

    Args:
        df: table with columns ``variant``, ``median_R`` and ``frac_I``.
        colors: one color per variant, in the order of ``pd.unique(df["variant"])``.

    Returns:
        None.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    # One scatter call per variant so each gets its own color
    variant_column = df["variant"]
    for idx, name in enumerate(pd.unique(variant_column)):
        subset = df[variant_column == name]
        ax.scatter(
            subset.median_R,
            subset.frac_I,
            color=colors[idx],
            edgecolors="k",
            s=60,
            zorder=3,
        )

    ax.set_xlabel("R at introduction")
    ax.set_ylabel("Relative wave size")

    fig.tight_layout()

In [None]:
colors = [color_map[v] for v in pd.unique(R_wave_size["variant"])]


In [None]:
colors

In [None]:
initial_R_wave_size(R_wave_size, colors)

# Fig: Correlation with vaccination with wave size

- Vaccination -> Wave size

- Pull out introduction date
- Get vaccination proportion on that date
- Plot vaccination at introduction against wave size
- Basically, want to look at lines by variant

In [None]:
#vaccination = pd.read_csv("regression_analysis_df.tsv", sep="\t")[["date", "location", "variant", "Series_Complete_Pop_Pct"]]
#vaccination["date"] = pd.to_datetime(vaccination["date"])

raw_vaccination = pd.read_csv("../../../Downloads/CDC-Vaccination.csv")
keep_cols = ["Date", "Location", "Series_Complete_Pop_Pct"]

# Select and rename the columns, parse and sort by date, map location
# abbreviations to full names, and turn the percentage into a proportion.
vaccination = (
    raw_vaccination[keep_cols]
    .rename(columns={"Date": "date", "Location": "location"})
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .sort_values("date")
    .replace({"location": abv_to_state})
    .assign(Series_Complete_Pop_Pct=lambda frame: frame["Series_Complete_Pop_Pct"] / 100)
)

In [None]:
vaccination

In [None]:
wave_size_vaccination = pd.merge(
    R_wave_size[["date", "location", "variant", "frac_I"]], 
    vaccination, on=["date", "location"]
)

In [None]:
wave_size_vaccination

In [None]:
def vaccination_wave_size(df, colors):
    """Scatter ``Series_Complete_Pop_Pct`` against ``frac_I``, one color per variant.

    Args:
        df: table with columns ``variant``, ``Series_Complete_Pop_Pct`` and ``frac_I``.
        colors: one color per variant, in the order of ``pd.unique(df["variant"])``.

    Returns:
        None.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    # One scatter call per variant so each gets its own color
    variant_column = df["variant"]
    for idx, name in enumerate(pd.unique(variant_column)):
        subset = df[variant_column == name]
        ax.scatter(
            subset.Series_Complete_Pop_Pct,
            subset.frac_I,
            color=colors[idx],
            edgecolors="k",
            s=60,
            zorder=3,
        )

    ax.set_xlabel("Proportion vaccinated at introduction")
    ax.set_ylabel("Relative wave size")

    fig.tight_layout()

In [None]:
colors = [color_map[v] for v in pd.unique(wave_size_vaccination["variant"])]


In [None]:
vaccination_wave_size(wave_size_vaccination, colors)

# Possibly, looking at BA.2.