In [None]:
import numpyro
import numpy as np
import pandas as pd
import jax.numpy as jnp

import rt_from_frequency_dynamics as rf

In [None]:
# Dataset identifier: used as both the sub-directory of ../data and the file-name prefix
data_name = "variants-us"

# Both inputs are tab-separated; read_table uses a tab delimiter by default
raw_cases = pd.read_table(f"../data/{data_name}/{data_name}_location-case-counts.tsv")
raw_seq = pd.read_table(f"../data/{data_name}/{data_name}_location-variant-sequence-counts.tsv")

In [None]:
# Locations to run: every distinct location present in the sequence counts
locations = raw_seq["location"].unique()

In [None]:
# First date of Omicron observation in each location (minimum of the `date` column)
raw_seq.loc[raw_seq["variant"] == "Omicron"].groupby("location")["date"].min()

In [None]:
# First observation date of each variant in Washington (minimum of the `date` column)
raw_seq.loc[raw_seq["location"] == "Washington"].groupby("variant")["date"].min()

In [None]:
# First observation date of each variant in Tennessee (minimum of the `date` column)
raw_seq.loc[raw_seq["location"] == "Tennessee"].groupby("variant")["date"].min()

In [None]:
# All sequence-count rows for Alpha in Washington
raw_seq.loc[raw_seq["location"].eq("Washington") & raw_seq["variant"].eq("Alpha")]

In [None]:
# Defining Lineage Models
# Passed positionally to rf.RenewalModel below; forecast_L is also read by the plotting helpers
seed_L, forecast_L = 14, 0

# Variant order; the generation-time list is built in this same order
v_names = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Mu', 'Omicron', 'other']

# Get delays
# `mn` passed to rf.discretise_gamma for each variant (std is 1.2 for all of them).
# Delta and Omicron use smaller values than the 5.2 shared by every other variant.
gen_mean_by_variant = {
    'Alpha': 5.2,
    'Beta': 5.2,
    'Delta': 3.6,
    'Epsilon': 5.2,
    'Gamma': 5.2,
    'Iota': 5.2,
    'Mu': 5.2,
    'Omicron': 3.2,
    'other': 5.2,
}
gen = rf.pad_delays(
    [rf.discretise_gamma(mn=gen_mean_by_variant[v], std=1.2) for v in v_names]
)

# A single discretised log-normal delay
delays = rf.pad_delays([rf.discretise_lognorm(mn=3.1, std=1.0)])

k_GARW = 20 # Number of spline basis elements

# Pick likelihoods
## R Likelihoods
GARW = rf.GARW(0.1, 0.01, prior_family="Normal")

CLik = rf.ZINegBinomCases(0.05) # Case likelihood
SLik = rf.DirMultinomialSeq(100) # Sequence count likelihood

# Defining models
LM_GARW = rf.RenewalModel(
    gen,
    delays,
    seed_L,
    forecast_L,
    k=k_GARW,
    RLik=GARW,
    CLik=CLik,
    SLik=SLik,
    v_names=v_names,
)

In [None]:
# Params for fitting
# Optimizer handed to rf.fit_SVI_locations
opt = numpyro.optim.Adam(step_size=4e-4)

# Forwarded as keyword arguments to rf.fit_SVI_locations
iters, num_samples = 60_000, 1000
save, load = True, False

In [None]:
# Paths for export
path_base = f"../estimates/{data_name}"
path_GARW = f"{path_base}/GARW"

# Set up the output directories for the GARW model before fitting
rf.make_model_directories(path_GARW)

# Running models and exporting results

In [None]:
# Running GARW model
# Fit options come from the "Params for fitting" cell; output goes under path_GARW
fit_options = dict(
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_GARW,
)
MP_GARW = rf.fit_SVI_locations(raw_cases, raw_seq, locations, LM_GARW, opt, **fit_options)

In [None]:
import matplotlib.pyplot as plt

# One loss trace per fitted location, on a log-scaled y axis
fig, ax = plt.subplots(figsize=(12.5, 7.5))

for loc_name in MP_GARW.locator.keys():
    ax.plot(MP_GARW.get(loc_name).dataset["loss"], label=loc_name)

ax.set(yscale="log", xlabel="Iterations", ylabel="Loss", title="GARW")
ax.legend()

## Loading results

In [None]:
# Loading past results
def load_models(rc, rs, locations, RM, path=".", num_samples=1000):
    """Build a MultiPosterior by sampling a loaded posterior for every location.

    For each entry of ``locations`` this builds the location's data with
    ``rf.get_location_VariantData(rc, rs, loc)``, passes it to
    ``rf.sample_loaded_posterior`` together with ``RM``, and adds the result
    to a fresh ``rf.MultiPosterior``. A progress line is printed per location.

    Parameters
    ----------
    rc, rs :
        Case-count and sequence-count inputs, forwarded unchanged to
        ``rf.get_location_VariantData``.
    locations :
        Sized iterable of location names; each name is also used as the
        ``name`` argument of ``rf.sample_loaded_posterior``.
    RM :
        Model object forwarded to ``rf.sample_loaded_posterior``.
    path : str
        Forwarded as ``path`` to ``rf.sample_loaded_posterior``.
    num_samples : int
        Forwarded as ``num_samples`` to ``rf.sample_loaded_posterior``.

    Returns
    -------
    The populated ``rf.MultiPosterior``.
    """
    MP = rf.MultiPosterior()
    n_total = len(locations)
    for i, loc in enumerate(locations, start=1):
        LD = rf.get_location_VariantData(rc, rs, loc)
        PH = rf.sample_loaded_posterior(LD, RM, num_samples=num_samples, path=path, name=loc)
        MP.add_posterior(PH)
        print(f"Location {loc} finished {i} / {n_total}")
    return MP

In [None]:
# Rebuild MP_GARW via load_models from path_GARW, drawing 3000 samples per location
MP_GARW = load_models(
    raw_cases,
    raw_seq,
    locations,
    LM_GARW,
    path=path_GARW,
    num_samples=3000,
)

In [None]:
# Exporting growth info
ps = [0.95, 0.8, 0.5] # Which credible intervals to save

In [None]:
# Export GARW
# Gather R, little r, I and frequency tables from the posteriors at the levels in `ps`
R_GARW = rf.gather_R(MP_GARW, ps)
r_GARW = rf.gather_little_r(MP_GARW, ps)
I_GARW = rf.gather_I(MP_GARW, ps)
freq_GARW = rf.gather_freq(MP_GARW, ps)

# Every table is written the same way: UTF-8, tab-separated, without the index column
tsv_options = dict(encoding='utf-8', sep='\t', index=False)

R_GARW.to_csv(f"{path_base}/{data_name}_Rt-combined-GARW.tsv", **tsv_options)
r_GARW.to_csv(f"{path_base}/{data_name}_little-r-combined-GARW.tsv", **tsv_options)
I_GARW.to_csv(f"{path_base}/{data_name}_I-combined-GARW.tsv", **tsv_options)
freq_GARW.to_csv(f"{path_base}/{data_name}_freq-combined-GARW.tsv", **tsv_options)

In [None]:
# Export growth advantages
ga_GARW = rf.gather_ga_time(MP_GARW, ps)

# Written as UTF-8, tab-separated, without the index column
ga_GARW.to_csv(
    f"{path_base}/{data_name}_ga-combined-GARW.tsv",
    encoding='utf-8',
    sep='\t',
    index=False,
)

In [None]:
# Figures for export
path_fig = f"{path_base}/figures"
rf.make_path_if_absent(path_fig)

In [None]:
# NOTE(review): the star import is kept because other cells may rely on names it
# provides; prefer explicit imports once those uses have been audited.
from rt_from_frequency_dynamics.plotfunctions import *

# Only the 0.8 level is drawn in the figures. This rebinds the `ps` used by the
# export cells, so re-run the "Exporting growth info" cell before exporting again.
# (A previous `ps = DefaultAes.ps` assignment was overwritten immediately and is dropped.)
ps = [0.8]
alphas = DefaultAes.alphas

# One color per variant, paired by position with v_names
v_colors =["#2e5eaa", "#5adbff",  "#56e39f","#b4c5e4", "#f03a47",  "#f5bb00", "#9e4244","#9932CC", "#808080"] 
v_names = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Mu', 'Omicron', 'other']
color_map = dict(zip(v_names, v_colors))

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

# Global matplotlib font settings applied to every figure drawn after this cell
font = dict(family='Helvetica', weight='light', size=32)

matplotlib.rc('font', **font)

In [None]:
# Quick posterior checks

In [None]:
# Variant frequencies for each location from free R model
def posterior_frequencies_by_country(MP, ps, alphas, color_map, forecast=False):
    """Draw posterior and observed variant frequencies, one panel per location.

    Panels are laid out five per row. Each panel overlays
    ``rf.plot_posterior_frequency`` and ``rf.plot_observed_frequency`` and is
    titled with the location name. A shared legend built from ``color_map``
    is placed at the bottom of the figure.

    Parameters
    ----------
    MP :
        Object whose ``locator`` keys are the location names; passed to
        ``rf.unpack_model`` for each location.
    ps, alphas :
        Forwarded unchanged to ``rf.plot_posterior_frequency``.
    color_map : dict
        Maps variant name to color. Indexed with each location's
        ``LD.seq_names``; all of its entries appear in the legend.
    forecast : bool
        Forwarded to ``rf.plot_posterior_frequency``. When true, a dashed
        vertical line is drawn at the last observed date index and the date
        axis is extended by the notebook-level ``forecast_L``.

    Returns
    -------
    The matplotlib figure.
    """
    loc_names = list(MP.locator.keys())

    panel_size = 10
    n_cols = 5
    # Ceiling division: just enough rows to hold every location
    n_rows = -(-len(loc_names) // n_cols)

    fig = plt.figure(figsize=(1.5*n_cols*panel_size, n_rows*panel_size))
    grid = fig.add_gridspec(nrows=n_rows, ncols=n_cols)

    for idx, loc in enumerate(loc_names):
        row, col = divmod(idx, n_cols)
        ax = fig.add_subplot(grid[row, col])
        dataset, LD = rf.unpack_model(MP, loc)
        colors = [color_map[v] for v in LD.seq_names]

        rf.plot_posterior_frequency(ax, dataset, ps, alphas, colors, forecast=forecast)
        rf.plot_observed_frequency(ax, LD, colors)
        ax.set_title(loc)

        # Adding dates depends on whether we're forecasting
        if not forecast:
            rf.add_dates(ax, LD.dates, sep=1)
        else:
            # forecast_L is read from the notebook's global scope
            ax.axvline(x=len(LD.dates)-1, color='k', linestyle='--')
            rf.add_dates_sep(ax, rf.expand_dates(LD.dates, forecast_L), sep=30)

        # Only the left-most panel of each row carries the y label
        if col == 0:
            ax.set_ylabel("Variant frequency")

    # Make legend
    labels = list(color_map.keys())
    patches = [matplotlib.patches.Patch(color=color_map[name], label=name) for name in labels]
    legend = fig.legend(patches, labels, ncol=len(labels), loc="lower center")
    frame = legend.get_frame()
    frame.set_linewidth(2.)
    frame.set_edgecolor("k")
    fig.tight_layout()
    # Leave room at the bottom for the figure-level legend
    fig.subplots_adjust(bottom = 0.1)
    return fig

In [None]:
fig_freq = posterior_frequencies_by_country(MP_GARW, ps, alphas, color_map, forecast=False)

In [None]:
# Variant cases for each location from free R model
def posterior_cases_by_country(MP, ps, alphas, color_map, forecast=False):
    """Draw posterior variant-specific I and observed cases, one panel per location.

    Panels are laid out five per row. Each panel overlays
    ``rf.plot_posterior_I`` and ``rf.plot_cases`` and is titled with the
    location name. A shared legend built from ``color_map`` is placed at the
    bottom of the figure.

    Parameters
    ----------
    MP :
        Object whose ``locator`` keys are the location names; passed to
        ``rf.unpack_model`` for each location.
    ps, alphas :
        Forwarded unchanged to ``rf.plot_posterior_I``.
    color_map : dict
        Maps variant name to color. Indexed with each location's
        ``LD.seq_names``; all of its entries appear in the legend.
    forecast : bool
        Forwarded to ``rf.plot_posterior_I``. When true, a dashed vertical
        line is drawn at the last observed date index and the date axis is
        extended by the notebook-level ``forecast_L``.

    Returns
    -------
    The matplotlib figure.
    """
    loc_names = list(MP.locator.keys())

    panel_size = 10
    n_cols = 5
    # Ceiling division: just enough rows to hold every location
    n_rows = -(-len(loc_names) // n_cols)

    fig = plt.figure(figsize=(1.9*n_cols*panel_size, n_rows*panel_size))
    grid = fig.add_gridspec(nrows=n_rows, ncols=n_cols)

    for idx, loc in enumerate(loc_names):
        row, col = divmod(idx, n_cols)
        ax = fig.add_subplot(grid[row, col])
        dataset, LD = rf.unpack_model(MP, loc)
        colors = [color_map[v] for v in LD.seq_names]

        rf.plot_posterior_I(ax, dataset, ps, alphas, colors, forecast=forecast)
        rf.plot_cases(ax, LD)
        ax.set_title(loc)

        # Adding dates depends on whether we're forecasting
        if not forecast:
            rf.add_dates(ax, LD.dates, sep=1)
        else:
            # forecast_L is read from the notebook's global scope
            ax.axvline(x=len(LD.dates)-1, color='k', linestyle='--')
            rf.add_dates_sep(ax, rf.expand_dates(LD.dates, forecast_L), sep=30)

        # Only the left-most panel of each row carries the y label
        if col == 0:
            ax.set_ylabel("Variant cases")

    # Make legend
    labels = list(color_map.keys())
    patches = [matplotlib.patches.Patch(color=color_map[name], label=name) for name in labels]
    legend = fig.legend(patches, labels, ncol=len(labels), loc="lower center")
    frame = legend.get_frame()
    frame.set_linewidth(2.)
    frame.set_edgecolor("k")
    fig.tight_layout()
    # Leave room at the bottom for the figure-level legend
    fig.subplots_adjust(bottom = 0.1)
    return fig

In [None]:
fig_cases = posterior_cases_by_country(MP_GARW, ps, alphas, color_map, forecast=False)

In [None]:
# Variant R for each location from free R model
def posterior_R_by_country(MP, ps, alphas, color_map, forecast=False):
    """Draw posterior variant-specific R, one panel per location.

    Panels are laid out five per row. Each panel is drawn with
    ``rf.plot_R_censored`` (``thres=0.001``) and titled with the location
    name. A shared legend built from ``color_map`` is placed at the bottom of
    the figure.

    Parameters
    ----------
    MP :
        Object whose ``locator`` keys are the location names; passed to
        ``rf.unpack_model`` for each location.
    ps, alphas :
        Forwarded unchanged to ``rf.plot_R_censored``.
    color_map : dict
        Maps variant name to color. Indexed with each location's
        ``LD.seq_names``; all of its entries appear in the legend.
    forecast : bool
        Forwarded to ``rf.plot_R_censored``. When true, a dashed vertical
        line is drawn at the last observed date index and the date axis is
        extended by the notebook-level ``forecast_L``.

    Returns
    -------
    The matplotlib figure.
    """
    locations = list(MP.locator.keys())

    size = 4
    n_per_row = 5
    # Ceiling division: just enough rows to hold every location
    n_rows = -(-len(locations) // n_per_row)

    fig = plt.figure(figsize=(1.9*n_per_row*size, n_rows*size))
    gs = fig.add_gridspec(nrows=n_rows, ncols=n_per_row)

    for i, loc in enumerate(locations):
        this_row, this_col = divmod(i, n_per_row)
        ax = fig.add_subplot(gs[this_row, this_col])
        dataset, LD = rf.unpack_model(MP, loc)
        colors = [color_map[v] for v in LD.seq_names]

        rf.plot_R_censored(ax, dataset, ps, alphas, colors, thres=0.001, forecast=forecast)
        ax.set_title(loc)

        # Adding dates depends on whether we're forecasting
        if forecast:
            # forecast_L is read from the notebook's global scope
            T_forecast = forecast_L
            ax.axvline(x=len(LD.dates)-1, color='k', linestyle='--')
            # Qualified with rf. so the call does not depend on a star import
            rf.add_dates_sep(ax, rf.expand_dates(LD.dates, T_forecast), sep=14)
        else:
            rf.add_dates(ax, LD.dates, sep=1)

        # Only the left-most panel of each row carries the y label
        if this_col == 0:
            ax.set_ylabel("Variant R")

    # Make legend
    patches = [matplotlib.patches.Patch(color=c, label=l) for l, c in color_map.items()]
    legend = fig.legend(patches, list(color_map.keys()), ncol=len(color_map.keys()), loc="lower center")
    legend.get_frame().set_linewidth(2.)
    legend.get_frame().set_edgecolor("k")
    fig.tight_layout()
    # Leave room at the bottom for the figure-level legend
    fig.subplots_adjust(bottom = 0.1)
    return fig

In [None]:
fig_R = posterior_R_by_country(MP_GARW, ps, alphas, color_map)

In [None]:
# Variant epidemic growth rates for each location from free R model
def posterior_epidemic_gr_by_country(MP, g, ps, alphas, color_map, forecast=False):
    """Draw posterior variant-specific epidemic growth rates, one panel per location.

    Panels are laid out five per row. Each panel is drawn with
    ``rf.plot_little_r_censored`` (``thres=0.02``) and titled with the
    location name. A shared legend built from ``color_map`` is placed at the
    bottom of the figure.

    Parameters
    ----------
    MP :
        Object whose ``locator`` keys are the location names; passed to
        ``rf.unpack_model`` for each location.
    g :
        Not used by this function; kept so existing calls keep working.
    ps, alphas :
        Forwarded unchanged to ``rf.plot_little_r_censored``.
    color_map : dict
        Maps variant name to color. Indexed with each location's
        ``LD.seq_names``; all of its entries appear in the legend.
    forecast : bool
        Forwarded to ``rf.plot_little_r_censored``. When true, a dashed
        vertical line is drawn at the last observed date index and the date
        axis is extended by the notebook-level ``forecast_L``.

    Returns
    -------
    The matplotlib figure.
    """
    loc_names = list(MP.locator.keys())

    panel_size = 4
    n_cols = 5
    # Ceiling division: just enough rows to hold every location
    n_rows = -(-len(loc_names) // n_cols)

    fig = plt.figure(figsize=(1.9*n_cols*panel_size, n_rows*panel_size))
    grid = fig.add_gridspec(nrows=n_rows, ncols=n_cols)

    for idx, loc in enumerate(loc_names):
        row, col = divmod(idx, n_cols)
        ax = fig.add_subplot(grid[row, col])
        dataset, LD = rf.unpack_model(MP, loc)
        colors = [color_map[v] for v in LD.seq_names]

        rf.plot_little_r_censored(ax, dataset, ps, alphas, colors, thres=0.02, forecast=forecast)
        ax.set_title(loc)

        # Adding dates depends on whether we're forecasting
        if not forecast:
            rf.add_dates(ax, LD.dates, sep=1)
        else:
            # forecast_L is read from the notebook's global scope
            ax.axvline(x=len(LD.dates)-1, color='k', linestyle='--')
            rf.add_dates_sep(ax, rf.expand_dates(LD.dates, forecast_L), sep=14)

        # Only the left-most panel of each row carries the y label
        if col == 0:
            ax.set_ylabel("Epidemic Growth Rate")

    # Make legend
    labels = list(color_map.keys())
    patches = [matplotlib.patches.Patch(color=color_map[name], label=name) for name in labels]
    legend = fig.legend(patches, labels, ncol=len(labels), loc="lower center")
    frame = legend.get_frame()
    frame.set_linewidth(2.)
    frame.set_edgecolor("k")
    fig.tight_layout()
    # Leave room at the bottom for the figure-level legend
    fig.subplots_adjust(bottom = 0.1)
    return fig

In [None]:
fig_little_r = posterior_epidemic_gr_by_country(MP_GARW, gen, ps, alphas, color_map)