In [1]:
import numpyro
from numpyro.infer.autoguide import AutoMultivariateNormal
import pandas as pd

In [2]:
import os,sys,inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir) 
from src.rt_from_frequency_dynamics import *

# Run multiple states

In [3]:
raw_cases = pd.read_csv("../../rt-from-frequency-dynamics/data/location-case-counts.tsv", sep="\t")
raw_seq = pd.read_csv("../../rt-from-frequency-dynamics/data/location-variant-sequence-counts.tsv", sep="\t")

In [13]:
def run_SVI(LD, optimizer, num_samples=1000, iters=100_000, name="test", path=".", save=True, export=True, load=False):       
    X = make_breakpoint_splines(LD.cases, 20)
    data = LD.make_numpyro_input(X)

    # Defining model
    g, delays = get_standard_delays()
    LM = FreeGrowthModel(g, delays, 7, 0)

    # Run SVI
    SVIH = SVIHandler(optimizer=optimizer)
    guide = AutoMultivariateNormal(LM.model)
    
    if load:
        SVIH.load_state(f"{path}/models/{name}_svi.p")
    
    loss = SVIH.fit(LM.model, guide, data, iters, log_each=0)
    print(f"Model {name} finished. Final loss: {loss}")
    
    if jnp.isnan(loss):
        return False
    if save:
        SVIH.save_state(f"{path}/models/{name}_svi.p")
    
    # Get samples
    samples = SVIH.predict(LM.model, guide, data, num_samples=num_samples)
    dataset = to_arviz(samples)
    
    if export:
        # Get dataframes
        R_dataframe = pd.DataFrame(get_R(dataset, LD, ps, name))    
        R_dataframe.to_csv(f"{path}/Rt/Rt_{name}.csv", encoding='utf-8', index=False)
    return True

In [14]:
def get_state_LD(rc, rs, loc):
    rc_l = rc[rc.location == loc].copy()
    rs_l = rs[rs.location==loc].copy()
    return LineageData(rc_l, rs_l)

def run_locations(rc, rs, locations, optimizer, **kwargs):
    n_locations = len(locations)
    sucesses = []
    for i, loc in enumerate(locations):
        LD = get_state_LD(raw_cases, raw_seq, loc)
        model_name = loc.replace(" ", "_")
        sucess = run_SVI(LD, optimizer, name=model_name, **kwargs)
        sucesses.append(sucess)
        print(f'Location {loc} finished ({i+1}/{n_locations}).')
    return locations[sucesses]

In [15]:
ps = DefaultAes.ps

In [16]:
locations = pd.unique(raw_seq.location)

optimizer = numpyro.optim.Adam(step_size=1.0e-2)
num_samples = 1500
iters = 50_000
path = "../sims/all-states-preprint-free"
save = True
export = True
load = False
succeded = run_locations(raw_cases, raw_seq, locations, optimizer, 
              num_samples=num_samples, iters=iters, path=path, save=save, export=export, load=load)

Model Alabama finished. Final loss: 3405.12548828125
Location Alabama finished (1/39).
Model Alaska finished. Final loss: 2793.943603515625
Location Alaska finished (2/39).
Model Arizona finished. Final loss: 4381.81787109375
Location Arizona finished (3/39).
Model Arkansas finished. Final loss: 2813.24951171875
Location Arkansas finished (4/39).
Model California finished. Final loss: 6552.666015625
Location California finished (5/39).
Model Colorado finished. Final loss: 4281.517578125
Location Colorado finished (6/39).
Model Connecticut finished. Final loss: 3606.1328125
Location Connecticut finished (7/39).
Model Florida finished. Final loss: 6355.2958984375
Location Florida finished (8/39).
Model Georgia finished. Final loss: 4460.96484375
Location Georgia finished (9/39).
Model Idaho finished. Final loss: 2717.18017578125
Location Idaho finished (10/39).
Model Illinois finished. Final loss: 4412.9169921875
Location Illinois finished (11/39).
Model Indiana finished. Final loss: 378

In [17]:
failed_locations = [l for l in locations if l not in succeded]
failed_locations

[]

In [18]:
def combine_exports(rc, rs, locations, path):   
    rt_list = []
    for i, loc in enumerate(locations):
        _loc = loc.replace(" ", "_")
        rt_list.append(pd.read_csv(f"{path}/Rt/Rt_{_loc}.csv"))
    return pd.concat(rt_list)

In [19]:
rt_df = combine_exports(raw_cases, raw_seq, locations, path)

In [20]:
rt_df

Unnamed: 0,date,location,variant,median_R,median_freq,R_upper_95,R_lower_95,R_upper_80,R_lower_80,R_upper_50,R_lower_50
0,2021-01-02,Alabama,Alpha,1.073730,0.018817,1.550390,0.701531,1.355073,0.811848,1.198824,0.912277
1,2021-01-03,Alabama,Alpha,1.096541,0.018817,1.524041,0.767612,1.351724,0.857391,1.210157,0.952431
2,2021-01-04,Alabama,Alpha,1.117060,0.018819,1.496823,0.774870,1.359018,0.892498,1.196342,0.952845
3,2021-01-05,Alabama,Alpha,1.130633,0.018870,1.505939,0.795053,1.365812,0.915992,1.230100,0.994914
4,2021-01-06,Alabama,Alpha,1.142889,0.019105,1.526867,0.817361,1.364859,0.924115,1.218689,0.984211
...,...,...,...,...,...,...,...,...,...,...,...
2187,2021-09-27,Wisconsin,other,1.284883,0.018166,1.487702,1.079312,1.410236,1.145974,1.339500,1.206707
2188,2021-09-28,Wisconsin,other,1.276162,0.018601,1.501597,1.060867,1.427426,1.143375,1.314946,1.176185
2189,2021-09-29,Wisconsin,other,1.263589,0.019162,1.528038,1.039983,1.409146,1.099998,1.316110,1.164876
2190,2021-09-30,Wisconsin,other,1.252354,0.019754,1.565484,1.010273,1.414665,1.077569,1.333336,1.165390


In [21]:
rt_df.to_csv(f"{path}/Rt/Rt_combined.csv", encoding='utf-8', index=False)