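This notebook performs the main model run: it preprocesses the data, builds the model, samples it with NumPyro's NUTS sampler, and saves the results to a pickle file.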

In [1]:
# Project code: data preprocessing, epidemiological parameters and the model definition.
from epimodel import EpidemiologicalParameters, preprocess_data
from epimodel.numpyro_models.models import ComplexDifferentEffectsModel

import numpy as np
import jax
import jax.numpy as jnp
import numpyro
# Expose 2 host devices so the 2 MCMC chains requested later (num_chains=2)
# can run in parallel. NOTE(review): NumPyro documents that this only takes
# effect if called before JAX initialises its backend, so keep it here in the
# first cell, ahead of any JAX computation.
numpyro.set_host_device_count(2)



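Parallelisation: the `numpyro.set_host_device_count(2)` call above exposes two host devices so that the two MCMC chains can run in parallel. (This replaces the Theano environment variables used by the earlier PyMC version of this notebook.)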


In [2]:
# Input dataset and the last day of data to include in the fit.
DATA_PATH = '../merged_data/data_final_nov.csv'
LAST_DAY = '2020-05-30'

# Preprocess the merged dataset, then apply the reopening mask
# (its printout is suppressed via print_out=False).
data = preprocess_data(DATA_PATH, last_day=LAST_DAY, smoothing=1)
data.mask_reopenings(print_out=False)

Dropping NPI Mask Wearing
Dropping NPI Travel Screen/Quarantine
Dropping NPI Travel Bans
Dropping NPI Public Transport Limited
Dropping NPI Internal Movement Limited
Dropping NPI Public Information Campaigns
Dropping NPI Symptomatic Testing
Masking invalid values


In [3]:
# Epidemiological parameter settings (generation interval and case/death
# delay distributions) packaged as keyword arguments for the model builder.
ep = EpidemiologicalParameters()
bd = ep.get_model_build_dict()
print(bd)

# Wrap the preprocessed data in the model class and build the NumPyro model
# function that the NUTS kernel will sample.
model = ComplexDifferentEffectsModel(data)
run_model = model.build_model(
    **bd
)

{'gi_mean_mean': 5.06, 'gi_mean_sd': 0.3265, 'gi_sd_mean': 2.11, 'gi_sd_sd': 0.5, 'deaths_delay_mean_mean': 21.819649695284962, 'deaths_delay_mean_sd': 1.0056755718977664, 'deaths_delay_disp_mean': 14.26238141720708, 'deaths_delay_disp_sd': 5.177442947725441, 'cases_delay_mean_mean': 10.92830227448381, 'cases_delay_mean_sd': 0.9387435298564465, 'cases_delay_disp_mean': 5.406593726647138, 'cases_delay_disp_sd': 0.2689502951493133}


In [4]:
# Dimensions of the active-NPI array passed to the sampler below.
# NOTE(review): presumably (regions, NPIs, days) — confirm against epimodel.
np.shape(data.ActiveCMs)

(41, 8, 130)

In [None]:
# Sample the model with NUTS: 2 chains (matching the 2 host devices set via
# numpyro.set_host_device_count in the first cell), with 500 warm-up and
# 2000 sampling iterations per chain.
#
# NOTE(review): the earlier PyMC run used pm.sample(2000, tune=500, chains=4,
# max_treedepth=12, target_accept=0.96). This NUTS kernel uses NumPyro's
# defaults for target_accept_prob / max_tree_depth instead — confirm that is
# intended, or pass them to NUTS explicitly.
kernel = numpyro.infer.NUTS(run_model)
print(f"Kernel: {kernel}")

# Keyword arguments: positional num_warmup/num_samples is deprecated in recent
# NumPyro releases and makes it easy to swap warm-up and sample counts.
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=500,
    num_samples=2000,
    num_chains=2,
    progress_bar=True,
)
print(f"MCMC: {mcmc}")

# Fixed PRNG key so the run is reproducible.
# NOTE(review): `.data` on NewCases returns the raw values without the mask
# (if it is a masked array); presumably masking is handled by the model via
# all_observed_active — confirm.
mcmc.run(
    jax.random.PRNGKey(1),
    jnp.array(data.ActiveCMs),
    jnp.array(data.NewCases.data),
    jnp.array(model.all_observed_active),
)
mcmc.print_summary()

In [None]:
# save results in a pickle file
import pickle
from pathlib import Path

# With the NumPyro workflow the posterior draws live on `mcmc`. `model.trace`
# was only assigned by the old (commented-out) PyMC `pm.sample` call, so it is
# presumably undefined here. `jax.device_get` converts the JAX arrays to host
# NumPy arrays so the pickle does not depend on JAX device buffers.
posterior_samples = jax.device_get(mcmc.get_samples())

trace_path = Path('traces/final_final_nov.pkl')
# Create the output directory on a fresh checkout instead of failing on open().
trace_path.parent.mkdir(parents=True, exist_ok=True)
# Context manager so the file handle is flushed and closed deterministically.
with trace_path.open('wb') as trace_file:
    pickle.dump(posterior_samples, trace_file)