In [15]:
import numpyro
import pandas as pd
import jax.numpy as jnp

import rt_from_frequency_dynamics as rf
from rt_from_frequency_dynamics import discretise_gamma
from rt_from_frequency_dynamics import get_standard_delays
from rt_from_frequency_dynamics import FreeGrowthModel, FixedGrowthModel

from rt_from_frequency_dynamics import get_location_VariantData
from rt_from_frequency_dynamics import fit_SVI, MultiPosterior
from rt_from_frequency_dynamics import make_model_directories
from rt_from_frequency_dynamics import gather_R

In [16]:
# Dataset identifier; it names both the folder under ../data and the file prefix.
data_name = "omicron-us"

# Location case counts (tab-separated file).
raw_cases = pd.read_csv(
    f"../data/{data_name}/{data_name}_location-case-counts.tsv",
    sep="\t",
)

# Location variant sequence counts (tab-separated file).
raw_seq = pd.read_csv(
    f"../data/{data_name}/{data_name}_location-variant-sequence-counts.tsv",
    sep="\t",
)

In [18]:
# Defining Lineage Models
seed_L, forecast_L = 14, 7

# Variant names, listed in the same order as the generation times below
v_names = ['Delta', 'Omicron', 'other']

# Get delays: one discretised gamma (mn, std) per variant, padded together
gen = rf.pad_delays([
    rf.discretise_gamma(mn=4.4, std=1.2),  # Delta
    rf.discretise_gamma(mn=3.1, std=1.2),  # Omicron
    rf.discretise_gamma(mn=4.4, std=1.2),  # Other
])

# A single discretised log-normal entry for `delays`
delays = rf.pad_delays([rf.discretise_lognorm(mn=3.1, std=1.0)])

# Components later passed to rf.RenewalModel as RLik (GARW or FGA), CLik and SLik
GARW = rf.GARW(0.1, 0.1)
FGA = rf.FixedGA(1.0)
CLik = rf.ZINegBinomCases(0.05)
SLik = rf.DirMultinomialSeq()

# Data will be constant between models
LD = get_location_VariantData(raw_cases, raw_seq, "Washington")

In [19]:
# Shared fitting params, reused by every sensitivity fit in this notebook
iters, num_samples = 50_000, 3000
save, load = True, False

# Optimizer handed to the fitting helpers
opt = numpyro.optim.Adam(step_size=1e-3)

## Varying the mean of the generation time

In [20]:
def fit_SVI_sensitivity_mean(LD, RLik, mns, opt, **fit_kwargs):
    """Fit one renewal model per candidate Omicron generation-time mean.

    For every value in ``mns`` a ``rf.RenewalModel`` is built in which only
    the Omicron gamma mean changes (Delta and other stay at ``mn=4.4``; all
    three use ``std=1.2``). Each model is fitted with ``fit_SVI`` under the
    name ``g_mean_<mn>`` and the result is added to a ``MultiPosterior``.
    A progress line is printed after each fit.

    Relies on the notebook-level names ``delays``, ``seed_L``,
    ``forecast_L``, ``CLik``, ``SLik`` and ``v_names``.

    Parameters
    ----------
    LD
        Data object passed unchanged to ``fit_SVI``.
    RLik
        Passed to ``rf.RenewalModel`` as its ``RLik`` argument.
    mns
        Sized iterable of means to try for the Omicron generation time.
    opt
        Optimizer passed unchanged to ``fit_SVI``.
    **fit_kwargs
        Extra keyword arguments forwarded to ``fit_SVI``.

    Returns
    -------
    MultiPosterior
        Collection holding one fitted posterior per entry of ``mns``.
    """
    n_mns = len(mns)
    posteriors = MultiPosterior()
    for i, mn in enumerate(mns):
        # Per-variant generation times, in v_names order
        variant_gens = [
            rf.discretise_gamma(mn=4.4, std=1.2),  # Delta
            rf.discretise_gamma(mn=mn, std=1.2),   # Omicron
            rf.discretise_gamma(mn=4.4, std=1.2),  # Other
        ]
        padded_gens = rf.pad_delays(variant_gens)
        model = rf.RenewalModel(
            padded_gens,
            delays,
            seed_L,
            forecast_L,
            k=10,
            RLik=RLik,
            CLik=CLik,
            SLik=SLik,
            v_names=v_names,
        )
        posterior = fit_SVI(LD, model, opt, name=f"g_mean_{mn}", **fit_kwargs)
        posteriors.add_posterior(posterior)
        print(f'Finished ({i+1}/{n_mns}).')
    return posteriors

In [21]:
# Run name (also the output file prefix) and output directory for the mean sweep
model_name_mean = "omicron-us-sensitivity-means"
path_mean = f"../estimates/{model_name_mean}"

# mns to loop over: 2.0, 2.5, ..., 6.0 (the 6.01 stop keeps 6.0 in the grid)
mns = jnp.arange(2.0, 6.01, 0.5)

In [23]:
# Free Model settings: GARW fits for the mean sweep, using the GARW subfolder of path_mean
path_GARW = path_mean + "/GARW"
make_model_directories(path_GARW)
MP_GARW = fit_SVI_sensitivity_mean(
    LD,
    GARW,
    mns,
    opt,
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_GARW,
)

Finished (1/9).
Finished (2/9).
Finished (3/9).
Finished (4/9).
Finished (5/9).
Finished (6/9).
Finished (7/9).
Finished (8/9).
Finished (9/9).


In [24]:
# Fixed model settings: FGA fits for the mean sweep, using the fixed subfolder of path_mean
path_fixed = path_mean + "/fixed"
make_model_directories(path_fixed)
MP_fixed = fit_SVI_sensitivity_mean(
    LD,
    FGA,
    mns,
    opt,
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_fixed,
)

../estimates/omicron-us-sensitivity-means/fixed created.
../estimates/omicron-us-sensitivity-means/fixed/models created.
Finished (1/9).
Finished (2/9).
Finished (3/9).
Finished (4/9).
Finished (5/9).
Finished (6/9).
Finished (7/9).
Finished (8/9).
Finished (9/9).


In [26]:
# Exporting growth info
ps = [0.95, 0.8, 0.5] # Which credible intervals to save

Unnamed: 0,date,location,variant,median_R,median_freq,R_upper_95,R_lower_95,R_upper_80,R_lower_80,R_upper_50,R_lower_50
0,2021-11-15,g_mean_2.0,Delta,0.585659,9.979598e-01,0.986828,0.2795903,0.8187542,0.36371708,0.6516364,0.4207992
1,2021-11-16,g_mean_2.0,Delta,0.603599,9.974052e-01,0.93598783,0.3706891,0.7760973,0.4146602,0.67109406,0.48543242
2,2021-11-17,g_mean_2.0,Delta,0.627997,9.966645e-01,0.87631667,0.42650717,0.76759183,0.47843435,0.6825261,0.53335714
3,2021-11-18,g_mean_2.0,Delta,0.655303,9.955301e-01,0.8650179,0.48425746,0.7738236,0.53520864,0.7159328,0.5925176
4,2021-11-19,g_mean_2.0,Delta,0.687580,9.948864e-01,0.8621847,0.52686834,0.7897701,0.5768864,0.7262682,0.6171513
...,...,...,...,...,...,...,...,...,...,...,...
400,2022-03-25,g_mean_6.0,other,0.437151,8.460582e-10,0.81218183,0.19134259,0.6283551,0.24888597,0.4942946,0.29789236
401,2022-03-26,g_mean_6.0,other,0.439132,7.372225e-10,0.83951926,0.18203819,0.64366174,0.24632107,0.4997559,0.29368827
402,2022-03-27,g_mean_6.0,other,0.442189,6.426140e-10,0.8681673,0.17187399,0.64594156,0.22799823,0.4799041,0.26428822
403,2022-03-28,g_mean_6.0,other,0.446725,5.500489e-10,0.8862831,0.1466332,0.6759152,0.23107922,0.4938985,0.2678355


In [27]:
# Gather R from the GARW fits at the levels listed in ps
R_GARW = rf.gather_R(MP_GARW, ps)

# Save files (tab-separated, without the index column)
R_GARW.to_csv(
    f"{path_mean}/{model_name_mean}_Rt-combined-GARW.tsv",
    sep='\t',
    index=False,
    encoding='utf-8',
)

In [28]:
# Gather R and ga from the fixed fits at the levels listed in ps
R_fixed = rf.gather_R(MP_fixed, ps)
ga_fixed = rf.gather_ga(MP_fixed, ps)

# Save files (tab-separated, without the index column)
R_fixed.to_csv(
    f"{path_mean}/{model_name_mean}_Rt-combined-fixed.tsv",
    sep='\t',
    index=False,
    encoding='utf-8',
)
ga_fixed.to_csv(
    f"{path_mean}/{model_name_mean}_ga-combined-fixed.tsv",
    sep='\t',
    index=False,
    encoding='utf-8',
)

## Varying the standard deviation of the generation time

In [29]:
def fit_SVI_sensitivity_sd(LD, RLik, sds, opt, **fit_kwargs):
    """Fit one renewal model per candidate Omicron generation-time std.

    For every value in ``sds`` a ``rf.RenewalModel`` is built in which only
    the Omicron gamma standard deviation changes (its mean stays at
    ``mn=3.1``; Delta and other stay at ``mn=4.4, std=1.2``). Each model is
    fitted with ``fit_SVI`` under the name ``g_sd_<sd>`` and the result is
    added to a ``MultiPosterior``. A progress line is printed after each fit.

    Relies on the notebook-level names ``delays``, ``seed_L``,
    ``forecast_L``, ``CLik``, ``SLik`` and ``v_names``.

    Parameters
    ----------
    LD
        Data object passed unchanged to ``fit_SVI``.
    RLik
        Passed to ``rf.RenewalModel`` as its ``RLik`` argument.
    sds
        Sized iterable of standard deviations to try for the Omicron
        generation time.
    opt
        Optimizer passed unchanged to ``fit_SVI``.
    **fit_kwargs
        Extra keyword arguments forwarded to ``fit_SVI``.

    Returns
    -------
    MultiPosterior
        Collection holding one fitted posterior per entry of ``sds``.
    """
    n_sd = len(sds)
    posteriors = MultiPosterior()
    for i, sd in enumerate(sds):
        # Per-variant generation times, in v_names order
        variant_gens = [
            rf.discretise_gamma(mn=4.4, std=1.2),  # Delta
            rf.discretise_gamma(mn=3.1, std=sd),   # Omicron
            rf.discretise_gamma(mn=4.4, std=1.2),  # Other
        ]
        padded_gens = rf.pad_delays(variant_gens)
        model = rf.RenewalModel(
            padded_gens,
            delays,
            seed_L,
            forecast_L,
            k=10,
            RLik=RLik,
            CLik=CLik,
            SLik=SLik,
            v_names=v_names,
        )
        posterior = fit_SVI(LD, model, opt, name=f"g_sd_{sd}", **fit_kwargs)
        posteriors.add_posterior(posterior)
        print(f'Finished ({i+1}/{n_sd}).')
    return posteriors

In [30]:
# sds to loop over: standard deviations 1.0, 1.5, ..., 6.0
# (the 6.01 stop keeps 6.0 in the grid, since arange excludes its stop value)
sds = jnp.arange(1.0, 6.01, 0.5)

In [33]:
# Name of the standard-deviation sensitivity run; also used as the output file prefix
model_name_sd = "omicron-us-sensitivity-sd"
# Directory under ../estimates used for this run's outputs
path_sd = f"../estimates/{model_name_sd}"

In [34]:
# Free Model settings: GARW fits for the standard-deviation sweep.
# NOTE: this rebinds path_GARW and MP_GARW, which held the mean-sweep values before.
path_GARW = path_sd + "/GARW"
make_model_directories(path_GARW)
MP_GARW = fit_SVI_sensitivity_sd(
    LD,
    GARW,
    sds,
    opt,
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_GARW,
)

../estimates/omicron-us-sensitivity-sd/GARW created.
../estimates/omicron-us-sensitivity-sd/GARW/models created.
Finished (1/11).
Finished (2/11).
Finished (3/11).
Finished (4/11).
Finished (5/11).
Finished (6/11).
Finished (7/11).
Finished (8/11).
Finished (9/11).
Finished (10/11).
Finished (11/11).


In [35]:
# Fixed model settings: FGA fits for the standard-deviation sweep.
# NOTE: this rebinds path_fixed and MP_fixed, which held the mean-sweep values before.
path_fixed = path_sd + "/fixed"
make_model_directories(path_fixed)
MP_fixed = fit_SVI_sensitivity_sd(
    LD,
    FGA,
    sds,
    opt,
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_fixed,
)

../estimates/omicron-us-sensitivity-sd/fixed created.
../estimates/omicron-us-sensitivity-sd/fixed/models created.
Finished (1/11).
Finished (2/11).
Finished (3/11).
Finished (4/11).
Finished (5/11).
Finished (6/11).
Finished (7/11).
Finished (8/11).
Finished (9/11).
Finished (10/11).
Finished (11/11).


In [37]:
# Exporting growth info
ps = [0.95, 0.8, 0.5] # Which credible intervals to save

In [38]:
# Export R for GARW (standard-deviation sweep)
R_GARW = rf.gather_R(MP_GARW, ps)

# Save files (tab-separated, without the index column)
R_GARW.to_csv(
    f"{path_sd}/{model_name_sd}_Rt-combined-GARW.tsv",
    sep='\t',
    index=False,
    encoding='utf-8',
)

In [39]:
# Export R and ga for FGA (standard-deviation sweep)
R_fixed = rf.gather_R(MP_fixed, ps)
ga_fixed = rf.gather_ga(MP_fixed, ps)

# Save files (tab-separated, without the index column)
R_fixed.to_csv(
    f"{path_sd}/{model_name_sd}_Rt-combined-fixed.tsv",
    sep='\t',
    index=False,
    encoding='utf-8',
)
ga_fixed.to_csv(
    f"{path_sd}/{model_name_sd}_ga-combined-fixed.tsv",
    sep='\t',
    index=False,
    encoding='utf-8',
)