In [1]:
import pandas as pd
import plotly.express as px

import pandas as pd
from datetime import datetime
import numpyro
from numpyro import distributions as dist
from numpyro import infer
from jax import jit, random
import arviz as az
from estival.sampling import tools as esamp
from plotly.express.colors import qualitative as qual_colours
from IPython.display import Markdown

from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti, plot_uncertainty_patches
from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.calibration import StandardCalib

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Grab data on 2014-2016 Ebola outbreak - this analysis we will focus on the epidemic in Seirra Leone. This data was don

ebola_data = pd.read_csv(r'C:\Users\ehug0006\emu\wpro-working\data\ebola_2014_2016_clean.csv', index_col="Date", parse_dates=True)

In [3]:
fig = px.line(ebola_data, x=ebola_data.index, y='Cumulative no. of confirmed, probable and suspected cases', color='Country')

fig.show()

In [4]:
ebola_data

Unnamed: 0_level_0,Country,"Cumulative no. of confirmed, probable and suspected cases","Cumulative no. of confirmed, probable and suspected deaths"
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2014-08-29,Guinea,648.0,430.0
2014-08-29,Nigeria,19.0,7.0
2014-08-29,Sierra Leone,1026.0,422.0
2014-08-29,Liberia,1378.0,694.0
2014-09-05,Sierra Leone,1261.0,491.0
...,...,...,...
2016-03-23,Liberia,10666.0,4806.0
2016-03-23,Italy,1.0,0.0
2016-03-23,Liberia,5.0,4.0
2016-03-23,Nigeria,20.0,8.0


In [5]:
case_data = ebola_data[ebola_data['Country']=='Sierra Leone']['Cumulative no. of confirmed, probable and suspected cases'].diff().rolling(14).mean().dropna()

In [6]:
case_data

Date
2014-10-17    170.285714
2014-10-22    174.642857
2014-10-25    181.071429
2014-10-29    272.214286
2014-10-31    265.571429
                 ...    
2015-12-17      0.000000
2015-12-22      0.000000
2015-12-23      0.000000
2015-12-29      0.000000
2016-03-23      0.000000
Name: Cumulative no. of confirmed, probable and suspected cases, Length: 245, dtype: float64

In [7]:
fig = px.line(case_data, x=case_data.index, y=case_data)

fig.show()

In [20]:
# Specify fixed parameters and get calibration data
run_in = 10
proc_update_freq = 4
pop = 7.1e6
analysis_start = ebola_data.index[0]
analysis_end = datetime(2015, 5, 1)
select_data = case_data.loc[analysis_start: analysis_end]

In [21]:
fitter = CosineMultiCurve()
renew_model = RenewalModel(33e6, analysis_start, analysis_end, run_in, proc_update_freq, fitter, GammaDens(), fitter, 50)

In [22]:
calib = StandardCalib(renew_model, select_data)

In [23]:
# Define parameter ranges
priors = {
    "gen_mean": dist.Uniform(10.0, 14.0),
    "gen_sd": dist.Uniform(3.0, 7.0),
    "cdr": dist.Beta(10.0, 4.0),
    "seed": dist.Uniform(0.4, 1.5),
}

In [24]:
kernel = numpyro.infer.NUTS(calib.calibration, dense_mass=True, init_strategy=infer.init_to_uniform(radius=0.5))
mcmc = numpyro.infer.MCMC(kernel, num_chains=2, num_samples=1000, num_warmup=1000)
rng_key = random.PRNGKey(1)
mcmc.run(rng_key, extra_fields=("accept_prob",), params=priors)


There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.

sample: 100%|██████████| 2000/2000 [01:56<00:00, 17.13it/s, 255 steps of size 1.67e-02. acc. prob=0.94]
sample: 100%|██████████| 2000/2000 [01:53<00:00, 17.57it/s, 255 steps of size 1.57e-02. acc. prob=0.92]


In [25]:
idata = az.from_numpyro(mcmc)

In [26]:
burn_in = 10
n_samples = 100
quantiles = [0.05, 0.5, 0.95]
idata_burnt = idata.sel(draw=slice(burn_in, None))
idata_sampled = az.extract(idata_burnt, num_samples=n_samples)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [27]:
def get_full_result(gen_mean, gen_sd, proc, seed, cdr):
    return renew_model.renewal_func(gen_mean, gen_sd, proc, seed)

full_wrap = jit(get_full_result)
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap)
quantiles_df = get_quant_df_from_spaghetti(renew_model, spaghetti, quantiles)

In [28]:
plot_uncertainty_patches(quantiles_df, select_data, qual_colours.Plotly)