In [None]:
import numpyro
import numpy as np
import pandas as pd
import jax.numpy as jnp

from rt_from_frequency_dynamics import discretise_gamma, discretise_lognorm, pad_delays
from rt_from_frequency_dynamics import get_standard_delays
from rt_from_frequency_dynamics import FreeGrowthModel, FixedGrowthModel

from rt_from_frequency_dynamics import get_location_LineageData
from rt_from_frequency_dynamics import fit_SVI_locations, MultiPosterior
from rt_from_frequency_dynamics import sample_loaded_posterior
from rt_from_frequency_dynamics import unpack_model
from rt_from_frequency_dynamics import make_path_if_absent, make_model_directories
from rt_from_frequency_dynamics import gather_free_Rt, gather_fixed_Rt

In [None]:
# Dataset identifier; it is interpolated into the input file names below.
data_name = "variants-us"

# Per-location case counts (tab-separated).
raw_cases = pd.read_csv(
    f"../data/{data_name}_location-case-counts.tsv",
    sep="\t",
)

# Per-location variant sequence counts (tab-separated).
raw_seq = pd.read_csv(
    f"../data/{data_name}_location-variant-sequence-counts.tsv",
    sep="\t",
)

In [None]:
# Locations to run: every distinct value of the "location" column in the
# sequence counts, in order of first appearance.
locations = raw_seq["location"].unique()

In [None]:
# Defining Lineage Models
# NOTE(review): seed_L / forecast_L look like the seeding and forecast
# lengths expected by the growth models -- confirm against the package.
seed_L, forecast_L = 7, 0

# Generation time and delay distributions shared by both models.
gen, delays = get_standard_delays()

# Build the free- and fixed-growth models from the same configuration.
LM_free, LM_fixed = (
    growth_model(gen, delays, seed_L, forecast_L)
    for growth_model in (FreeGrowthModel, FixedGrowthModel)
)

In [None]:
# Params for fitting (all forwarded to fit_SVI_locations below)
opt = numpyro.optim.Adam(step_size=1e-2)  # optimizer handed to the SVI fit

# Number of fitting iterations and number of samples requested per location.
iters, num_samples = 50_000, 3000

# Flags passed through as `save=` / `load=` to fit_SVI_locations.
save, load = True, False

In [None]:
# Paths for export: one root per dataset with a sub-folder per model type.
path_base = f"../estimates/{data_name}"
path_free, path_fixed = path_base + "/free", path_base + "/fixed"

# Prepare the output directories for each model before fitting.
# NOTE(review): presumably creates the folders if missing -- confirm in package.
make_model_directories(path_free)
make_model_directories(path_fixed)

# Running models and exporting results

In [None]:
# Running free model: fit every location with the free-growth model,
# using the shared optimizer and fitting parameters.
MP_free = fit_SVI_locations(
    raw_cases,
    raw_seq,
    locations,
    LM_free,
    opt,
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_free,
)

In [None]:
# Running fixed model: fit every location with the fixed-growth model,
# using the shared optimizer and fitting parameters.
MP_fixed = fit_SVI_locations(
    raw_cases,
    raw_seq,
    locations,
    LM_fixed,
    opt,
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_fixed,
)

## Loading results

In [None]:
# Loading past results
def load_models(rc, rs, locations, model_type, path=".", num_samples=1000,
                seed_L=7, forecast_L=0):
    """Rebuild a MultiPosterior from previously saved per-location results.

    Parameters
    ----------
    rc, rs
        Raw case counts and raw sequence counts; both are passed unchanged to
        ``get_location_LineageData`` for each location.
    locations
        Sized iterable of location names to load.
    model_type
        Model class (e.g. ``FreeGrowthModel``), called as
        ``model_type(g, delays, seed_L, forecast_L)``.
    path
        Location handed to ``sample_loaded_posterior`` as ``path=``.
    num_samples
        Number of samples requested from ``sample_loaded_posterior``.
    seed_L, forecast_L
        Third and fourth positional arguments of ``model_type``. The defaults
        (7 and 0) reproduce the previously hard-coded values; pass the values
        used at fitting time if they differ.

    Returns
    -------
    MultiPosterior
        Container holding one posterior per entry of ``locations``.
    """
    g, delays = get_standard_delays()
    LM = model_type(g, delays, seed_L, forecast_L)
    MP = MultiPosterior()
    n_locations = len(locations)  # loop-invariant, used only for progress output
    for i, loc in enumerate(locations, start=1):
        LD = get_location_LineageData(rc, rs, loc)
        # NOTE(review): presumably reads the saved fit stored under `path`
        # for `loc` -- confirm against the package.
        PH = sample_loaded_posterior(LD, LM, num_samples=num_samples, path=path, name=loc)
        MP.add_posterior(PH)
        print(f"Location {loc} finished {i} / {n_locations}")
    return MP

In [None]:
# Rebuild both posteriors from the saved results. The sample count comes from
# the fitting configuration (`num_samples`) rather than a duplicated literal,
# so reloaded draws stay consistent with the fitted ones.
MP_free = load_models(raw_cases, raw_seq, locations, FreeGrowthModel, path=path_free, num_samples=num_samples)
MP_fixed = load_models(raw_cases, raw_seq, locations, FixedGrowthModel, path=path_fixed, num_samples=num_samples)

In [None]:
# Exporting growth info
# Credible-interval levels to save; the same list is passed to both gather calls.
ps = [0.95, 0.8, 0.5]

# Free model: the generation time `gen` is supplied explicitly as `g=`.
R_free, r_free = gather_free_Rt(
    MP_free,
    ps,
    g=gen,
    path=path_base,
    name=data_name,
)

# Fixed model: returns a third output in addition to the two above.
R_fixed, r_fixed, ga_fixed = gather_fixed_Rt(
    MP_fixed,
    ps,
    path=path_base,
    name=data_name,
)