In [1]:
# Uncomment the line below to install on Colab or similar
#! pip install git+https://github.com/monash-emu/renewal.git

In [2]:
from jax import jit, random
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from plotly.express.colors import qualitative as qual_colours
import numpy as np

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel, ModelResult
from emu_renewal.outputs import get_spaghetti, get_quant_df_from_spaghetti
from emu_renewal.outputs import plot_uncertainty_patches, plot_post_prior_comparison
from emu_renewal.calibration import StandardCalib
from emu_renewal.targets import StandardTarget

In [3]:
# Specify fixed parameters and get calibration data
proc_update_freq = 14
window_len = 32
mys_data = pd.read_csv("https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv", index_col=0)["MYS"]
mys_data.index = pd.to_datetime(mys_data.index)
pop = 33e6
analysis_start = datetime(2021, 3, 1)
analysis_end = datetime(2021, 11, 1)
init_start = analysis_start - timedelta(window_len)
init_end = analysis_start - timedelta(1)
select_data = mys_data.loc[analysis_start: analysis_end]
init_data = mys_data.loc[init_start: init_end]

In [4]:
renew_model = RenewalModel(33e6, analysis_start, analysis_end, proc_update_freq, CosineMultiCurve(), GammaDens(), window_len, init_data, GammaDens())

In [5]:
# Define parameter ranges
priors = {
    "gen_mean": dist.Uniform(6.5, 10.5),
    "gen_sd": dist.Uniform(3.0, 4.6),
    "cdr": dist.Beta(4.0, 10.0),
    "rt_init": dist.Normal(0.0, 0.25),
    "report_mean": dist.Uniform(8.0, 12.0),
    "report_sd": dist.Uniform(3.0, 6.0),
}

In [6]:
targets = {
    "cases": StandardTarget(select_data, 0.1)
}

In [7]:
calib = StandardCalib(renew_model, priors, targets)
kernel = infer.NUTS(calib.calibration, dense_mass=True, init_strategy=calib.custom_init(radius=0.5))
mcmc = infer.MCMC(kernel, num_chains=2, num_samples=100, num_warmup=100)
mcmc.run(random.PRNGKey(1))

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
idata = az.from_dict(mcmc.get_samples(True))
idata_sampled = az.extract(idata, num_samples=200)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [9]:
spaghetti = new_new_get_spaghetti(calib, sample_params)
key_outputs = ["cases", "suscept", "r_t", "process"]
quantiles = get_quant_df_from_spaghetti(spaghetti, quantiles=[0.05, 0.5, 0.95])

In [10]:
spaghetti

Unnamed: 0_level_0,cases,cases,cases,cases,cases,cases,cases,cases,cases,cases,...,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum
Unnamed: 0_level_1,"(0, 0)","(0, 1)","(0, 10)","(0, 11)","(0, 12)","(0, 13)","(0, 14)","(0, 15)","(0, 16)","(0, 17)",...,"(1, 90)","(1, 91)","(1, 92)","(1, 93)","(1, 94)","(1, 95)","(1, 96)","(1, 97)","(1, 98)","(1, 99)"
2021-03-01,2616.844566,2639.830524,2618.284816,2624.858531,2632.365207,2631.405249,2618.958897,2657.873422,2643.512926,2642.288753,...,19458.358016,19529.745029,19480.783350,19523.982496,19507.533343,19479.627690,19475.293756,19501.727391,19458.267846,19480.486144
2021-03-02,2517.502674,2544.120519,2522.353215,2535.560466,2539.180307,2537.453765,2521.356786,2567.496058,2550.020320,2554.040968,...,19088.274622,19182.332552,19120.175515,19173.093121,19142.877591,19117.146257,19094.410591,19132.824459,19094.393874,19129.679121
2021-03-03,2409.121006,2436.083405,2417.750032,2437.656339,2435.094474,2432.662592,2414.775894,2462.505467,2444.440775,2455.195702,...,18625.099416,18768.291774,18676.761552,18749.897340,18692.368631,18669.916925,18612.524047,18674.029922,18641.479642,18703.746063
2021-03-04,2299.246513,2324.344893,2311.334201,2337.526713,2327.745696,2324.747258,2306.418943,2351.002889,2335.333626,2352.600263,...,18080.843550,18269.868672,18150.703769,18250.441566,18156.954876,18142.340536,18040.693379,18129.355100,18112.171408,18207.294436
2021-03-05,2191.749324,2214.239970,2206.507677,2238.543170,2221.534311,2218.126294,2199.934295,2238.554695,2228.131391,2250.354137,...,17485.400988,17706.837027,17568.725805,17700.402685,17564.345801,17563.026853,17410.003237,17527.803797,17535.307925,17668.153622
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2021-10-28,5518.021317,5705.125140,5463.831296,5270.641861,5834.033181,5587.050553,5480.567008,5618.369314,5590.509189,5524.002803,...,40736.507028,42018.387983,38918.022085,40545.354248,39429.972556,40616.680763,41030.647882,42128.372206,40375.628289,39060.295478
2021-10-29,5350.047486,5644.227640,5429.100843,5154.922633,5791.823743,5540.512606,5477.532308,5579.127231,5516.630181,5389.026759,...,40311.397762,41030.345673,38615.310294,39488.068055,39135.763442,39829.874528,40608.555679,40894.174835,39686.936073,38382.516371
2021-10-30,5177.997538,5581.007329,5401.495274,5036.114010,5758.359262,5506.309008,5492.857395,5548.816930,5445.921895,5252.676144,...,40009.517285,39950.350364,38433.073625,38384.479773,38953.887055,38986.711539,40246.374420,39600.610491,39060.915482,37737.357678
2021-10-31,5003.897716,5515.989205,5379.552247,4915.128605,5732.214679,5482.466473,5523.639566,5526.141428,5377.743367,5115.751186,...,39824.943627,38784.454328,38368.201976,37239.287776,38879.885055,38092.134538,39937.699618,38257.851213,38494.266693,37121.944144


In [11]:
quantiles=[0.025, 0.5]
outputs = set(spaghetti.columns.get_level_values(0))
column_names = pd.MultiIndex.from_product([outputs, quantiles])
quantiles_df = pd.DataFrame(index=spaghetti.index, columns=column_names)
for col in outputs:
    quantiles_df[col] = spaghetti[col].quantile(quantiles, axis=1).T

In [12]:
quantiles_df

Unnamed: 0_level_0,r_t,r_t,seropos,seropos,suscept,suscept,cases,cases,process,process,incidence,incidence,weekly_sum,weekly_sum
Unnamed: 0_level_1,0.025,0.500,0.025,0.500,0.025,0.500,0.025,0.500,0.025,0.500,0.025,0.500,0.025,0.500
2021-03-01,0.521939,0.594249,0.000094,0.000197,3.186043e+07,3.256943e+07,2618.261005,2634.110120,0.534184,0.602412,3080.480842,6402.427917,19453.950076,19510.442384
2021-03-02,0.521697,0.594097,0.000186,0.000389,3.184458e+07,3.256301e+07,2517.485156,2543.286024,0.534184,0.602412,3029.699636,6278.162513,19083.484713,19155.468932
2021-03-03,0.521457,0.593947,0.000277,0.000577,3.182899e+07,3.255675e+07,2403.760025,2440.043799,0.534184,0.602412,2969.413519,6168.307875,18612.499706,18724.833504
2021-03-04,0.521222,0.593800,0.000365,0.000762,3.181373e+07,3.255069e+07,2285.999182,2332.933031,0.534184,0.602412,2893.012157,6015.938215,18059.804309,18206.812418
2021-03-05,0.520991,0.593657,0.000450,0.000943,3.179891e+07,3.254485e+07,2171.478843,2226.662761,0.534184,0.602412,2796.807245,5751.151512,17434.751924,17624.879444
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2021-10-28,0.701748,0.898160,0.127353,0.267013,8.983589e+06,2.387983e+07,5054.889595,5498.716952,0.868969,1.268174,9826.956320,19780.160521,37611.859231,40135.837628
2021-10-29,0.681591,0.898977,0.127724,0.267564,8.941758e+06,2.386180e+07,4972.412638,5428.494107,0.853098,1.281045,9489.242934,19449.080065,36834.676224,39566.521793
2021-10-30,0.668103,0.902447,0.128101,0.268105,8.901848e+06,2.384400e+07,4847.957345,5365.990712,0.843514,1.292994,9029.809773,19424.807119,36160.788335,39102.411175
2021-10-31,0.659288,0.903298,0.128483,0.268624,8.863669e+06,2.382664e+07,4712.960890,5310.301706,0.837619,1.299100,8710.471616,19393.300202,35491.980717,38524.628808


In [13]:
plot_uncertainty_patches(quantiles, select_data, qual_colours.Plotly, outputs=key_outputs).update_layout(showlegend=False)

TypeError: list indices must be integers or slices, not str

In [None]:
az.summary(idata)

In [None]:
plot_post_prior_comparison(idata, list(priors.keys()), priors);

In [None]:
parameters = {i: j for i, j in [k for k in sample_params.iterrows()][0][1].items() if "dispersion" not in i}
result = renew_model.renewal_func(**parameters)

In [None]:
pd.DataFrame(
    {
        "cases": np.array(result.cases),
        "weekly_case": np.array(result.weekly_sum),
        "cases_target": np.array(targets["cases"].data)
    }
).plot()