In [1]:
import numpyro
from numpyro.infer.autoguide import AutoMultivariateNormal
import pandas as pd

In [2]:
import os,sys,inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir) 
from src.rt_from_frequency_dynamics import *

# Run multiple states

In [3]:
raw_cases = pd.read_csv("../../rt-from-frequency-dynamics/data/location-case-counts.tsv", sep="\t")
raw_seq = pd.read_csv("../../rt-from-frequency-dynamics/data/location-variant-sequence-counts.tsv", sep="\t")

In [14]:
def run_SVI(LD, optimizer, num_samples=1000, iters=100_000, name="test", path=".", save=True, export=True, load=False):       
    X = make_breakpoint_splines(LD.cases, 20)
    data = LD.make_numpyro_input(X)

    # Defining model
    g, delays = get_standard_delays()
    LM = FixedGrowthModel(g, delays, 7, 0)

    # Run SVI
    SVIH = SVIHandler(optimizer=optimizer)
    guide = AutoMultivariateNormal(LM.model)
    
    if load:
        SVIH.load_state(f"{path}/models/{name}_svi.p")
    
    loss = SVIH.fit(LM.model, guide, data, iters, log_each=0)
    print(f"Model {name} finished. Final loss: {loss}")
    
    if jnp.isnan(loss):
        return False
    if save:
        SVIH.save_state(f"{path}/models/{name}_svi.p")
    
    # Get samples
    samples = SVIH.predict(LM.model, guide, data, num_samples=num_samples)
    dataset = to_arviz(samples)
    
    if export:
        # Get dataframes
        R_dataframe = pd.DataFrame(get_R(dataset, LD, ps, name))
        ga_dataframe = pd.DataFrame(get_growth_advantage(dataset, LD, ps, name))
    
        R_dataframe.to_csv(f"{path}/Rt/Rt_{name}.csv", encoding='utf-8', index=False)
        ga_dataframe.to_csv(f"{path}/ga/ga_{name}.csv", encoding='utf-8', index=False)
    return True

In [15]:
def get_state_LD(rc, rs, loc):
    rc_l = rc[rc.location == loc].copy()
    rs_l = rs[rs.location==loc].copy()
    return LineageData(rc_l, rs_l)

def run_locations(rc, rs, locations, optimizer, **kwargs):
    n_locations = len(locations)
    sucesses = []
    for i, loc in enumerate(locations):
        LD = get_state_LD(raw_cases, raw_seq, loc)
        model_name = loc.replace(" ", "_")
        sucess = run_SVI(LD, optimizer, name=model_name, **kwargs)
        sucesses.append(sucess)
        print(f'Location {loc} finished ({i+1}/{n_locations}).')
    return locations[sucesses]

In [16]:
ps = DefaultAes.ps

In [17]:
locations = pd.unique(raw_seq.location)

optimizer = numpyro.optim.Adam(step_size=1.0e-2)
num_samples = 1500
iters = 50_000
path = "../sims/all-states-preprint-fixed"
save = True
export = True
load = False
succeded = run_locations(raw_cases, raw_seq, locations, optimizer, 
              num_samples=num_samples, iters=iters, path=path, save=save, export=export, load=load)

Model Alabama finished. Final loss: 3424.47216796875
Location Alabama finished (1/39).
Model Alaska finished. Final loss: 2735.58056640625
Location Alaska finished (2/39).
Model Arizona finished. Final loss: 4603.634765625
Location Arizona finished (3/39).
Model Arkansas finished. Final loss: 2814.448486328125
Location Arkansas finished (4/39).
Model California finished. Final loss: 7058.9541015625
Location California finished (5/39).
Model Colorado finished. Final loss: 4430.9453125
Location Colorado finished (6/39).
Model Connecticut finished. Final loss: 3607.63232421875
Location Connecticut finished (7/39).
Model Florida finished. Final loss: 6995.302734375
Location Florida finished (8/39).
Model Georgia finished. Final loss: 4691.76171875
Location Georgia finished (9/39).
Model Idaho finished. Final loss: 2719.794921875
Location Idaho finished (10/39).
Model Illinois finished. Final loss: 4654.9423828125
Location Illinois finished (11/39).
Model Indiana finished. Final loss: 3819.

In [8]:
failed_locations = [l for l in locations if l not in succeded]
failed_locations

[]

In [18]:
def combine_exports(rc, rs, locations, path):   
    rt_list = []
    ga_list = []
    for i, loc in enumerate(locations):
        _loc = loc.replace(" ", "_")
        rt_list.append(pd.read_csv(f"{path}/Rt/Rt_{_loc}.csv"))
        ga_list.append(pd.read_csv(f"{path}/ga/ga_{_loc}.csv"))
    return pd.concat(rt_list), pd.concat(ga_list)

In [19]:
rt_df, ga_df = combine_exports(raw_cases, raw_seq, locations, path)

In [20]:
rt_df

Unnamed: 0,date,location,variant,median_R,median_freq,R_upper_95,R_lower_95,R_upper_80,R_lower_80,R_upper_50,R_lower_50
0,2021-01-02,Alabama,Alpha,0.836635,0.206123,1.106170,0.628429,1.010743,0.691329,0.910063,0.749228
1,2021-01-03,Alabama,Alpha,0.834950,0.206123,1.046187,0.653002,0.956509,0.696619,0.892448,0.760427
2,2021-01-04,Alabama,Alpha,0.830126,0.206125,1.006731,0.676193,0.936606,0.720686,0.875174,0.763146
3,2021-01-05,Alabama,Alpha,0.827720,0.206222,0.980682,0.697525,0.921473,0.740097,0.876328,0.779955
4,2021-01-06,Alabama,Alpha,0.826028,0.206807,0.958518,0.706894,0.901327,0.742232,0.867851,0.784863
...,...,...,...,...,...,...,...,...,...,...,...
2187,2021-09-27,Wisconsin,other,0.580588,0.000009,0.758534,0.428158,0.689381,0.471093,0.635712,0.527606
2188,2021-09-28,Wisconsin,other,0.575112,0.000008,0.784677,0.400988,0.699281,0.446163,0.626731,0.500327
2189,2021-09-29,Wisconsin,other,0.569884,0.000007,0.812579,0.362373,0.719739,0.420419,0.654209,0.503706
2190,2021-09-30,Wisconsin,other,0.565590,0.000007,0.864555,0.333752,0.720405,0.367581,0.639298,0.460430


In [21]:
ga_df[ga_df.variant == "Delta"]

Unnamed: 0,location,variant,median_ga,ga_upper_95,ga_lower_95,ga_upper_80,ga_lower_80,ga_upper_50,ga_lower_50
2,Alabama,Delta,1.444503,1.482694,1.409529,1.465111,1.41938,1.455859,1.431119
2,Alaska,Delta,2.049501,2.082941,2.017593,2.0706,2.02892,2.061109,2.039122
2,Arizona,Delta,1.804441,1.842422,1.766417,1.828771,1.779424,1.819133,1.793402
1,Arkansas,Delta,1.553488,1.594496,1.510766,1.58143,1.526939,1.568896,1.540415
2,California,Delta,1.693251,1.735197,1.655827,1.7193,1.666953,1.705103,1.677055
2,Colorado,Delta,1.855789,1.873478,1.83823,1.865448,1.843306,1.861626,1.849939
2,Connecticut,Delta,1.910323,2.004669,1.812613,1.973063,1.849209,1.948222,1.881351
2,Florida,Delta,1.911286,1.925594,1.898726,1.919027,1.901508,1.916526,1.907324
2,Georgia,Delta,1.796994,1.811794,1.780459,1.807067,1.786935,1.802571,1.792381
2,Idaho,Delta,1.659152,1.740138,1.580142,1.710763,1.6084,1.686381,1.630233


In [22]:
rt_df.to_csv(f"{path}/Rt/Rt_combined.csv", encoding='utf-8', index=False)
ga_df.to_csv(f"{path}/ga/ga_combined.csv", encoding='utf-8', index=False)