In [1]:
import numpyro
import numpy as np
import pandas as pd
import jax.numpy as jnp

import rt_from_frequency_dynamics as rf



In [2]:
data_name = "pango-countries"

# Tab-separated inputs for this dataset: location case counts and
# location/variant sequence counts.
raw_cases = pd.read_csv(
    f"../data/{data_name}/{data_name}_location-case-counts.tsv",
    sep="\t",
)
raw_seq = pd.read_csv(
    f"../data/{data_name}/{data_name}_location-variant-sequence-counts.tsv",
    sep="\t",
)

In [3]:
# Locations to run: every distinct location in the sequence data, in order of first appearance
locations = raw_seq["location"].unique()

In [4]:
# Display the locations that will be fitted
locations

array(['USA'], dtype=object)

In [5]:
# Distinct variant labels in the sequence data (order of first appearance); passed to the model as v_names
v_names = raw_seq["variant"].unique()

In [6]:
# Display the variant names that the model is built with (v_names)
v_names

array(['other', 'B.1.1.529', 'BA.2', 'BA.2.12', 'BA.2.12.1', 'BA.2.13',
       'BA.2.13.1', 'BA.2.18', 'BA.2.3', 'BA.2.3.20', 'BA.2.38',
       'BA.2.56', 'BA.2.75', 'BA.2.75.1', 'BA.2.75.2', 'BA.2.75.3.1',
       'BA.2.75.5', 'BA.2.76', 'BA.2.9', 'BA.4', 'BA.4.1', 'BA.4.1.1',
       'BA.4.1.6', 'BA.4.1.8', 'BA.4.1.9', 'BA.4.2', 'BA.4.3', 'BA.4.4',
       'BA.4.6', 'BA.5', 'BA.5.1', 'BA.5.10.1', 'BA.5.1.1', 'BA.5.1.10',
       'BA.5.1.12', 'BA.5.1.18', 'BA.5.1.2', 'BA.5.1.21', 'BA.5.1.22',
       'BA.5.1.23', 'BA.5.1.24', 'BA.5.1.25', 'BA.5.1.27', 'BA.5.1.3',
       'BA.5.1.5', 'BA.5.1.6', 'BA.5.1.7', 'BA.5.1.8', 'BA.5.2',
       'BA.5.2.1', 'BA.5.2.16', 'BA.5.2.18', 'BA.5.2.19', 'BA.5.2.2',
       'BA.5.2.20', 'BA.5.2.21', 'BA.5.2.22', 'BA.5.2.23', 'BA.5.2.26',
       'BA.5.2.27', 'BA.5.2.28', 'BA.5.2.3', 'BA.5.2.31', 'BA.5.2.33',
       'BA.5.2.34', 'BA.5.2.6', 'BA.5.2.7', 'BA.5.2.8', 'BA.5.2.9',
       'BA.5.3', 'BA.5.3.1', 'BA.5.5', 'BA.5.5.1', 'BA.5.5.2', 'BA.5.5.3',
       'BA.5.

In [7]:
# Defining Lineage Models
seed_L = 14      # seed length handed to RenewalModel (units presumably days — confirm)
forecast_L = 7   # forecast length handed to RenewalModel (units presumably days — confirm)

# Discretised gamma distribution (presumably the generation time — confirm)
gen = rf.discretise_gamma(mn=3.1, std=1.2)

# Single discretised log-normal delay, padded into the form RenewalModel expects
delays = rf.pad_delays(
    [rf.discretise_lognorm(mn=3.1, std=1.0)]
)

# GARW renewal model with zero-inflated negative-binomial cases and
# Dirichlet-multinomial sequence counts
LM_GARW = rf.RenewalModel(
    gen,
    delays,
    seed_L,
    forecast_L,
    k=10,
    RLik=rf.GARW(1e-2, 1e-3, prior_family="Normal"),
    CLik=rf.ZINegBinomCases(0.05),
    SLik=rf.DirMultinomialSeq(100),
    v_names=v_names,
)

In [8]:
# Params for fitting: SVI optimiser, iteration budget and posterior sample count
opt = numpyro.optim.Adam(step_size=1e-3)
iters, num_samples = 100_000, 3000

# save / load flags forwarded to rf.fit_SVI_locations
save, load = True, False

In [9]:
# Paths for export: one base directory per dataset, one subdirectory per model
path_base = f"../estimates/{data_name}"
path_GARW = f"{path_base}/GARW"
rf.make_model_directories(path_GARW)

## Running models and exporting results

In [10]:
# Running GARW model: fit every location with SVI, optionally saving/loading via path_GARW
MP_GARW = rf.fit_SVI_locations(
    raw_cases,
    raw_seq,
    locations,
    LM_GARW,
    opt,
    iters=iters,
    num_samples=num_samples,
    save=save,
    load=load,
    path=path_GARW,
)

Location USA finished (1/1).


## Loading results

In [11]:
# Loading past results
def load_models(rc, rs, locations, RM, path=".", num_samples=1000):
    """Assemble a MultiPosterior from previously saved per-location fits.

    For each entry of ``locations`` this extracts that location's data from
    ``rc``/``rs`` with ``rf.get_location_VariantData``, then calls
    ``rf.sample_loaded_posterior`` with the model ``RM``, ``num_samples``,
    ``path`` and ``name=<location>`` to obtain a posterior. Each posterior is
    added to a single ``rf.MultiPosterior``, and a progress line is printed
    after every location.

    Returns the populated ``rf.MultiPosterior``.
    """
    posteriors = rf.MultiPosterior()
    for idx, location in enumerate(locations, start=1):
        location_data = rf.get_location_VariantData(rc, rs, location)
        posterior = rf.sample_loaded_posterior(
            location_data, RM, num_samples=num_samples, path=path, name=location
        )
        posteriors.add_posterior(posterior)
        print(f"Location {location} finished {idx} / {len(locations)}")
    return posteriors

In [12]:
# Reload saved GARW posteriors, drawing the same number of samples configured for fitting
MP_GARW = load_models(raw_cases, raw_seq, locations, LM_GARW, path=path_GARW, num_samples=num_samples)

Location USA finished 1 / 1


In [13]:
# Exporting growth info
ps = [0.95, 0.8, 0.5]  # Which credible intervals to save

In [14]:
# Export GARW: gather R_t, little r, I and variant frequencies at the credible levels in ps
R_GARW = rf.gather_R(MP_GARW, ps)
r_GARW = rf.gather_little_r(MP_GARW, ps)
I_GARW = rf.gather_I(MP_GARW, ps)
freq_GARW = rf.gather_freq(MP_GARW, ps)

# One TSV per quantity; only the label in the filename varies, so write them in a loop
# (same files, same order, same options as writing each one by hand).
for label, frame in (
    ("Rt", R_GARW),
    ("little-r", r_GARW),
    ("I", I_GARW),
    ("freq", freq_GARW),
):
    frame.to_csv(f"{path_base}/{data_name}_{label}-combined-GARW.tsv", encoding='utf-8', sep='\t', index=False)

In [15]:
# Export Forecasts: same quantities as the in-sample export, gathered with forecast=True
R_GARW_f = rf.gather_R(MP_GARW, ps, forecast=True)
r_GARW_f = rf.gather_little_r(MP_GARW, ps, forecast=True)
I_GARW_f = rf.gather_I(MP_GARW, ps, forecast=True)
freq_GARW_f = rf.gather_freq(MP_GARW, ps, forecast=True)

# One TSV per quantity; only the label in the filename varies, so write them in a loop
# (same files, same order, same options as writing each one by hand).
for label, frame in (
    ("Rt", R_GARW_f),
    ("little-r", r_GARW_f),
    ("I", I_GARW_f),
    ("freq", freq_GARW_f),
):
    frame.to_csv(f"{path_base}/{data_name}_{label}-combined-forecast-GARW.tsv", encoding='utf-8', sep='\t', index=False)