In [1]:
# To install on Colab or a similar hosted environment, uncomment the line below.
# (%pip installs into the running kernel's environment, unlike !pip.)
# %pip install git+https://github.com/monash-emu/renewal.git

In [2]:
from jax import jit, random
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from plotly.express.colors import qualitative as qual_colours
import numpy as np

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel, ModelResult
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, new_get_spaghetti, new_new_get_spaghetti
from emu_renewal.outputs import plot_uncertainty_patches, plot_post_prior_comparison
from emu_renewal.calibration import StandardCalib
from emu_renewal.targets import StandardTarget

In [3]:
# Fixed (non-calibrated) settings and the calibration data.
proc_update_freq = 14
window_len = 32  # length, in days, of the initialisation window before analysis_start

# Malaysia ("MYS") column of the new-cases file, re-indexed by date.
mys_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv",
    index_col=0,
)["MYS"]
mys_data.index = pd.to_datetime(mys_data.index)
pop = 33e6

analysis_start = datetime(2021, 3, 1)
analysis_end = datetime(2021, 11, 1)

# The initialisation window is the window_len days immediately preceding
# analysis_start: [start - window_len days, start - 1 day].
init_start = analysis_start - timedelta(days=window_len)
init_end = analysis_start - timedelta(days=1)

# Label-based .loc slicing on dates includes both endpoints.
select_data = mys_data.loc[analysis_start:analysis_end]
init_data = mys_data.loc[init_start:init_end]

In [4]:
# Build the renewal model over the analysis period. The population is taken from
# `pop` in the configuration cell instead of being repeated as a literal, so the
# two can no longer drift apart.
# NOTE(review): the roles of the positional arguments are inferred from names and
# the prior set (gen_* / report_*) — confirm against emu_renewal.renew.RenewalModel.
renew_model = RenewalModel(
    pop,
    analysis_start,
    analysis_end,
    proc_update_freq,
    CosineMultiCurve(),  # process curve, presumably interpolated between update points
    GammaDens(),  # presumably the generation-interval density (gen_mean / gen_sd)
    window_len,
    init_data,
    GammaDens(),  # presumably the reporting-delay density (report_mean / report_sd)
)

In [5]:
# Prior distributions for the calibrated parameters.
# NOTE(review): parameter meanings below are inferred from their names — confirm
# against emu_renewal.renew.RenewalModel.
priors = dict(
    gen_mean=dist.Uniform(6.5, 10.5),  # presumably generation-interval mean (days)
    gen_sd=dist.Uniform(3.0, 4.6),  # presumably generation-interval SD (days)
    cdr=dist.Beta(4.0, 10.0),  # presumably case detection ratio; Beta keeps it in (0, 1)
    rt_init=dist.Normal(0.0, 0.25),  # initial process value; its scale is not visible here
    report_mean=dist.Uniform(8.0, 12.0),  # presumably reporting-delay mean (days)
    report_sd=dist.Uniform(3.0, 6.0),  # presumably reporting-delay SD (days)
)

In [6]:
# Calibration targets keyed by output name ("cases" matches result.cases below).
# NOTE(review): the meaning of 0.1 (likely a likelihood dispersion/noise setting)
# is not visible here — confirm in emu_renewal.targets.StandardTarget.
targets = dict(cases=StandardTarget(select_data, 0.1))

In [7]:
# Calibrate with NUTS using a dense mass matrix: 2 chains, each with 100 warmup
# and 100 retained draws, from a fixed PRNG seed.
# NOTE(review): custom_init(radius=0.5) is StandardCalib's own initialisation
# strategy; its exact behaviour is defined in emu_renewal.calibration.
calib = StandardCalib(renew_model, priors, targets)
kernel = infer.NUTS(
    calib.calibration,
    dense_mass=True,
    init_strategy=calib.custom_init(radius=0.5),
)
mcmc = infer.MCMC(kernel, num_warmup=100, num_samples=100, num_chains=2)
mcmc.run(random.PRNGKey(1))

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
# Keep the draws separated by chain when converting to ArviZ format, then extract
# 200 draws (with 2 chains x 100 retained samples, that is every draw) and wrap
# them as an estival sample iterator for re-running the model.
idata = az.from_dict(mcmc.get_samples(group_by_chain=True))
idata_sampled = az.extract(idata, num_samples=200)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [9]:
# Outputs to show in the uncertainty plot further down.
key_outputs = ["cases", "suscept", "r_t", "process"]
# Run the model for every sampled parameter set; columns are (output, draw) pairs.
spaghetti = new_new_get_spaghetti(calib, sample_params)
# Summarise the per-draw trajectories with the 5th, 50th and 95th percentiles.
quantiles = get_quant_df_from_spaghetti(spaghetti, quantiles=[0.05, 0.5, 0.95])

KeyboardInterrupt: 

In [12]:
# Inspect the raw per-draw trajectories; columns are (output, draw) pairs.
# NOTE(review): this frame is wide (one column per draw per output) — consider
# spaghetti.head() or a column subset when sharing the notebook.
spaghetti

Unnamed: 0_level_0,cases,cases,cases,cases,cases,cases,cases,cases,cases,cases,...,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum,weekly_sum
Unnamed: 0_level_1,"(0, 0)","(0, 1)","(0, 10)","(0, 11)","(0, 12)","(0, 13)","(0, 14)","(0, 15)","(0, 16)","(0, 17)",...,"(1, 90)","(1, 91)","(1, 92)","(1, 93)","(1, 94)","(1, 95)","(1, 96)","(1, 97)","(1, 98)","(1, 99)"
2021-03-01,2625.731046240152,2680.2529892250286,2627.8848495906245,2628.911832406573,2634.071273725947,2637.4642164354627,2657.900262906155,2673.1433159617723,2623.3072255309394,2616.6535019477337,...,19460.130678672682,19480.839310373845,19490.203779843152,19456.818857220875,19451.075674412445,19461.92314680215,19477.95225988544,19481.19907986251,19446.72951800111,19453.152028567594
2021-03-02,2541.414538152596,2611.4642772902657,2540.2306961368486,2542.8278077809814,2545.0696765590833,2549.202428924856,2574.987029754776,2593.5296347469157,2527.190580798841,2507.575855313227,...,19094.25480248436,19125.730227904703,19142.519275989802,19085.75537808962,19071.382005897252,19085.15586197211,19115.778236718837,19116.529995736884,19074.6989626636,19084.693497333607
2021-03-03,2449.613601302112,2528.2493074680874,2442.738743085011,2446.9450215401325,2443.880426337446,2448.812211042541,2480.349725215498,2501.0004177441724,2421.366126797023,2386.9316900555323,...,18640.273763420337,18694.98707400532,18727.649133207684,18621.6775840423,18591.281255333048,18611.334310167178,18671.90389470597,18668.141598329745,18607.295744723004,18624.972840897557
2021-03-04,2355.619734208212,2434.7063579249448,2341.816885358825,2347.4354021887025,2337.5252982057054,2343.7703630404812,2380.7037397449353,2401.8819470559224,2313.2729949596455,2264.62334355646,...,18105.47677930366,18184.622941217596,18227.175293835346,18073.28442523034,18021.267354817173,18048.074740073993,18142.825727024647,18126.80553403248,18058.261912923143,18083.61483543155
2021-03-05,2262.188106112294,2335.5095682135716,2241.0024746381823,2247.811172394953,2230.2511631063317,2238.7924199966596,2280.6617974539367,2301.0624241589635,2206.904156037646,2146.1531645215223,...,17519.40953786844,17621.733759617386,17661.65054552118,17470.262353542028,17391.862615945178,17425.726477578533,17556.92418783831,17518.696427985313,17457.497025851662,17490.09953884023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2021-10-28,5366.503696160878,5591.235793760594,5221.710140134238,5785.739668868161,5243.333093582565,5407.803452277147,5630.3279254734325,5416.671192003907,4916.80877556538,5228.922460877915,...,39988.18530383832,40255.8832654069,37411.16317515443,41389.819471794726,39500.00837118207,42562.972013116734,41138.28964323597,41560.8961062381,40441.43726299862,39887.318899731574
2021-10-29,5301.341979323107,5511.13141081908,5149.409372307198,5742.025345134253,5179.019139859709,5359.6629776110285,5569.7658908123985,5381.634111688219,4811.957652205626,5131.344595094202,...,39317.03283542248,39989.95724770576,36512.02892008186,40661.81711186895,38903.57025267612,41668.40904612466,40371.81249501872,40903.57755047542,40119.74239635025,39029.4340668507
2021-10-30,5238.270348851442,5433.761878340325,5085.468649199665,5702.936939206924,5121.539609201425,5320.183148377578,5513.625548853081,5355.949243300889,4716.615343651846,5034.511923656161,...,38648.875260512585,39817.996086052895,35689.14757946405,39920.556679125,38323.53454323577,40723.039499948776,39623.55149655166,40293.62186321305,39824.951718704695,38170.675619761634
2021-10-31,5176.894308524715,5358.809988089924,5028.596157685878,5667.678870069373,5069.756912756201,5287.150521073798,5461.178653088469,5338.2103965742,4629.119073481431,4938.5310620441805,...,37982.6825756488,39736.354187277684,34939.561909282966,39166.56359315305,37758.309935038036,39732.24360891667,38892.57494608343,39728.411037715225,39553.23794098796,37310.99486580546


In [10]:
# Median-only quantile table, computed per output directly from the spaghetti.
# The quantile levels get their own name: reusing `quantiles` here would overwrite
# the DataFrame returned by get_quant_df_from_spaghetti, which the uncertainty plot
# consumes, and break Restart & Run All.
quantile_levels = [0.5]
# Keep the first-level column order of `spaghetti`; a set's order is not stable
# across runs, which made the resulting column order nondeterministic.
outputs = list(spaghetti.columns.get_level_values(0).unique())
# Each sub-frame's quantiles (levels x dates) are transposed to dates x levels;
# concatenating a dict on axis=1 yields (output, level) MultiIndex columns.
quantiles_df = pd.concat(
    {col: spaghetti[col].quantile(quantile_levels, axis=1).T for col in outputs},
    axis=1,
)

KeyboardInterrupt: 

In [None]:
# Uncertainty bands for the key outputs, with the calibration data overlaid.
# NOTE: `quantiles` must be the DataFrame from get_quant_df_from_spaghetti,
# not a list of quantile levels.
(
    plot_uncertainty_patches(
        quantiles,
        select_data,
        qual_colours.Plotly,
        outputs=key_outputs,
    )
    .update_layout(showlegend=False)
)

In [None]:
# ArviZ summary table of the posterior for every sampled parameter.
az.summary(idata)

In [None]:
# Compare posterior against prior for each calibrated parameter; the trailing
# semicolon suppresses the return-value repr.
plot_post_prior_comparison(idata, list(priors), priors);

In [None]:
# Re-run the model at the first sampled parameter set.
# Only the first row is needed, so it is taken straight from the iterator rather
# than materialising every row of `sample_params` into a list first.
# Names containing "dispersion" are dropped — presumably likelihood-only parameters
# that renewal_func does not accept (NOTE(review): confirm against StandardCalib).
parameters = {
    name: value
    for name, value in next(iter(sample_params.iterrows()))[1].items()
    if "dispersion" not in name
}
result = renew_model.renewal_func(**parameters)

In [None]:
# Plot the single-draw model outputs against the calibration target data.
# np.array converts each series to a plain NumPy array, so rows align by position
# and all three must have the same length.
pd.DataFrame(
    dict(
        cases=np.array(result.cases),
        weekly_case=np.array(result.weekly_sum),
        cases_target=np.array(targets["cases"].data),
    )
).plot()