In [1]:
import numpyro
import pandas as pd
import jax.numpy as jnp

from rt_from_frequency_dynamics import discretise_gamma
from rt_from_frequency_dynamics import get_standard_delays
from rt_from_frequency_dynamics import FreeGrowthModel, FixedGrowthModel

from rt_from_frequency_dynamics import get_location_LineageData
from rt_from_frequency_dynamics import fit_SVI, MultiPosterior
from rt_from_frequency_dynamics import make_model_directories
from rt_from_frequency_dynamics import gather_free_Rt, gather_fixed_Rt



In [2]:
# Prefix shared by both input file names under ../data.
data_name = "variants-us"
# Both inputs are tab-separated files; the raw frames are handed to
# get_location_LineageData further down.
raw_cases = pd.read_csv(f"../data/{data_name}_location-case-counts.tsv", sep="\t")
raw_seq = pd.read_csv(f"../data/{data_name}_location-variant-sequence-counts.tsv", sep="\t")

In [3]:
# Defining Fixed parameters
# seed_L and forecast_L are read as globals by the sensitivity helpers below and
# passed as the third and fourth positional arguments to every model constructor.
# NOTE(review): presumably seeding-period and forecast lengths in days — confirm.
seed_L = 7
forecast_L = 0   
# Only the second value returned by get_standard_delays() is kept.
_, delays = get_standard_delays()


# Data will be constant between models
# LD is built once for "Washington" and reused unchanged by every fit in this notebook.
LD = get_location_LineageData(raw_cases, raw_seq, "Washington")

In [4]:
# Shared fitting params
# The same optimizer object is passed to every fit_SVI call in this notebook.
opt = numpyro.optim.Adam(step_size=1.0e-2)
# The four values below are forwarded to fit_SVI as keyword arguments
# (iters=, num_samples=, save=, load=) by the sensitivity helpers.
iters = 50_000
num_samples = 3000
save = True
load = False

## Varying means of the generation time

In [5]:
def fit_SVI_sensitivity_mean(LD, modelclass, mns, opt, **fit_kwargs):
    """Fit one model per candidate generation-time mean and collect the results.

    For each value in `mns`, a generation time is built with
    `discretise_gamma` (standard deviation held at 1.72), a `modelclass`
    instance is constructed from it, and that model is fit with `fit_SVI`
    under the name ``g_mean_<mn>``. A progress line is printed after each fit.

    Reads the notebook-level globals `delays`, `seed_L` and `forecast_L`, so
    those must be defined before this function is called.

    Parameters
    ----------
    LD : data object passed unchanged to `fit_SVI` for every fit.
    modelclass : callable invoked as ``modelclass(g, delays, seed_L, forecast_L)``.
    mns : sized iterable of generation-time means to try.
    opt : optimizer passed unchanged to `fit_SVI`.
    **fit_kwargs : extra keyword arguments forwarded to `fit_SVI`.

    Returns
    -------
    The `MultiPosterior` to which each fit's result was added via
    `add_posterior`, one per entry of `mns`.
    """
    total = len(mns)
    posteriors = MultiPosterior()
    # Count from 1 so the progress message needs no arithmetic.
    for step, mean in enumerate(mns, start=1):
        generation_time = discretise_gamma(mn=mean, std=1.72)
        model = modelclass(generation_time, delays, seed_L, forecast_L)
        posteriors.add_posterior(
            fit_SVI(LD, model, opt, name=f"g_mean_{mean}", **fit_kwargs)
        )
        print(f'Finished ({step}/{total}).')
    return posteriors

In [6]:
# mns to loop over
# Stop value 6.01 sits just past 6.0 so the endpoint is included: 2.0, 2.5, ..., 6.0 (9 values).
mns = jnp.arange(2.0, 6.01, 0.5)
# Name and output directory shared by the free and fixed runs of the mean sensitivity analysis.
model_name_mean = "variants-us-sensitivity-means"
path_mean = f"../estimates/{model_name_mean}"

In [7]:
# Free Model settings
# Fits FreeGrowthModel once per value in mns; path_free is forwarded to fit_SVI as path=.
path_free = path_mean + "/free"
make_model_directories(path_free)
MP_free = fit_SVI_sensitivity_mean(LD, FreeGrowthModel, mns, opt, 
                                 iters=iters, num_samples=num_samples, save=save, load=load, path=path_free)

Finished (1/9).
Finished (2/9).
Finished (3/9).
Finished (4/9).
Finished (5/9).
Finished (6/9).
Finished (7/9).
Finished (8/9).
Finished (9/9).


In [8]:
# Fixed model settings
# Fits FixedGrowthModel once per value in mns; path_fixed is forwarded to fit_SVI as path=.
path_fixed = path_mean + "/fixed"
make_model_directories(path_fixed)
MP_fixed = fit_SVI_sensitivity_mean(LD, FixedGrowthModel, mns, opt, 
                                 iters=iters, num_samples=num_samples, save=save, load=load, path=path_fixed)

Finished (1/9).
Finished (2/9).
Finished (3/9).
Finished (4/9).
Finished (5/9).
Finished (6/9).
Finished (7/9).
Finished (8/9).
Finished (9/9).


In [9]:
# Exporting growth info
ps = [0.95, 0.8, 0.5] # Which credible intervals to save

In [10]:
# Gather Rt from the free-model posteriors of the mean sensitivity run, using the levels in ps.
R_free = gather_free_Rt(MP_free, ps, path=path_mean, name=model_name_mean)

In [11]:
# Gather results from the fixed-model posteriors; this call returns two objects.
# NOTE(review): path= and name= are the same as in the gather_free_Rt call for this
# run — confirm the two helpers write to distinct files.
R_fixed, ga_fixed = gather_fixed_Rt(MP_fixed, ps, path=path_mean, name=model_name_mean)

## Varying standard deviation of the generation time

In [12]:
def fit_SVI_sensitivity_sd(LD, modelclass, sds, opt, **fit_kwargs):
    """Fit one model per candidate generation-time standard deviation.

    For each value in `sds`, a generation time is built with
    `discretise_gamma` (mean held at 5.2), a `modelclass` instance is
    constructed from it, and that model is fit with `fit_SVI` under the name
    ``g_sd_<sd>``. A progress line is printed after each fit.

    Reads the notebook-level globals `delays`, `seed_L` and `forecast_L`, so
    those must be defined before this function is called.

    Parameters
    ----------
    LD : data object passed unchanged to `fit_SVI` for every fit.
    modelclass : callable invoked as ``modelclass(g, delays, seed_L, forecast_L)``.
    sds : sized iterable of generation-time standard deviations to try.
    opt : optimizer passed unchanged to `fit_SVI`.
    **fit_kwargs : extra keyword arguments forwarded to `fit_SVI`.

    Returns
    -------
    The `MultiPosterior` to which each fit's result was added via
    `add_posterior`, one per entry of `sds`.
    """
    total = len(sds)
    posteriors = MultiPosterior()
    # Count from 1 so the progress message needs no arithmetic.
    for step, spread in enumerate(sds, start=1):
        generation_time = discretise_gamma(mn=5.2, std=spread)
        model = modelclass(generation_time, delays, seed_L, forecast_L)
        posteriors.add_posterior(
            fit_SVI(LD, model, opt, name=f"g_sd_{spread}", **fit_kwargs)
        )
        print(f'Finished ({step}/{total}).')
    return posteriors

In [13]:
# sds to loop over
# Stop value 6.01 sits just past 6.0 so the endpoint is included: 1.0, 1.5, ..., 6.0 (11 values).
sds = jnp.arange(1.0, 6.01, 0.5)

In [14]:
# Name and output directory shared by the free and fixed runs of the sd sensitivity analysis.
model_name_sd = "variants-us-sensitivity-sd"
path_sd = f"../estimates/{model_name_sd}"

# NOTE(review): from here on this cell rebinds path_free, MP_free, path_fixed,
# MP_fixed, R_free, R_fixed and ga_fixed, which the mean-sensitivity cells assigned
# earlier; after it runs, those names refer to the sd-sensitivity results.

# Free Model settings
path_free = path_sd + "/free"
make_model_directories(path_free)
MP_free = fit_SVI_sensitivity_sd(LD, FreeGrowthModel, sds, opt, 
                                iters=iters, num_samples=num_samples, save=save, load=load, path=path_free)

# Fixed model settings
path_fixed = path_sd + "/fixed"
make_model_directories(path_fixed)
MP_fixed = fit_SVI_sensitivity_sd(LD, FixedGrowthModel, sds, opt, 
                                 iters=iters, num_samples=num_samples, save=save, load=load, path=path_fixed)

# Exporting growth info
# Same levels as the ps assigned in the earlier export cell.
ps = [0.95, 0.8, 0.5] # Which credible intervals to save
R_free = gather_free_Rt(MP_free, ps, path=path_sd, name=model_name_sd)
R_fixed, ga_fixed = gather_fixed_Rt(MP_fixed, ps, path=path_sd, name=model_name_sd)

Finished (1/11).
Finished (2/11).
Finished (3/11).
Finished (4/11).
Finished (5/11).
Finished (6/11).
Finished (7/11).
Finished (8/11).
Finished (9/11).
Finished (10/11).
Finished (11/11).
Finished (1/11).
Finished (2/11).
Finished (3/11).
Finished (4/11).
Finished (5/11).
Finished (6/11).
Finished (7/11).
Finished (8/11).
Finished (9/11).
Finished (10/11).
Finished (11/11).
