In [None]:
from typing import List
import numpy as np
import pandas as pd
pd.options.plotting.backend = 'plotly'
from datetime import datetime
from estival import priors as esp
import pymc as pm
import arviz as az

from emu_renewal.renew import RenewalModel, TruncRenewalModel
from emu_renewal.calibrate import get_wrapped_ll, use_model

In [None]:
# Calibration window and model-structure settings.
run_in = 30  # model time steps simulated before the first target observation
n_process_periods = 12  # length of the non-mechanistic process vector

# Reported new cases for Malaysia (the 'MYS' column), indexed by date.
raw_data = pd.read_csv(
    'https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv',
    index_col=0,
)['MYS']
raw_data.index = pd.to_datetime(raw_data.index)

# Keep the calibration window, re-index by position, then shift by run_in so
# that target t is compared with model time step t.
mys_data = raw_data.loc[datetime(2021, 3, 1): datetime(2021, 11, 1)].reset_index(drop=True)
mys_data.index = mys_data.index + run_in
n_times = run_in + len(mys_data)

calib_kwargs = dict(pop=33e6, n_times=n_times, run_in=run_in, targets=mys_data)

In [None]:
# Define parameter ranges: for each scalar calibration parameter an 'init'
# value plus the lower/upper bounds used for its uniform prior.
scalar_req = {
    'Generation time mean (days)': {'init': 5.0, 'lower': 0.1, 'upper': 14.0},
    'Generation time standard deviation (days)': {'init': 5.0, 'lower': 2.5, 'upper': 8.0},
    'Case detection proportion': {'init': 0.06, 'lower': 0.04, 'upper': 0.2},
    # The seed rate is calibrated on the log scale.
    'Log starting seed rate': {'init': np.log(1e4), 'lower': np.log(5e3), 'upper': np.log(2e4)},
}
# One row per parameter, columns init / lower / upper.
params_df = pd.DataFrame.from_dict(scalar_req, orient='index')
# Shared range for every element of the non-mechanistic process vector.
proc_req = dict(init=0.0, lower=-2.0, upper=2.0)

In [None]:
def get_obj_func(model, settings=None):
    """Build the objective function handed to ``use_model`` for calibration.

    The returned callable runs ``model.func`` for one parameter set, scales the
    modelled incidence by the case detection proportion and scores it against
    the calibration targets.

    Args:
        model: Object exposing ``func(gen_time_mean, gen_time_sd, process, seed)``
            whose result has an ``incidence`` attribute indexable by time step.
        settings: Calibration settings containing at least ``'targets'``, a
            mapping / Series of model time step -> observed cases. Defaults to
            the notebook-level ``calib_kwargs``, so existing calls are unchanged.
            It is resolved once here rather than on every objective call.

    Returns:
        ``obj_func_ll(gen_time_mean, gen_time_sd, cdr_param, seed, proc_params)``,
        returning the negative sum of squared errors (higher is better).
        NOTE(review): this is a least-squares score, i.e. a Gaussian
        log-likelihood with variance 0.5 up to an additive constant, not a
        calibrated error model; with case counts in the thousands it is very peaked.
    """
    if settings is None:
        settings = calib_kwargs  # notebook global, kept for backward compatibility
    targets = settings['targets']

    def obj_func_ll(gen_time_mean, gen_time_sd, cdr_param, seed, proc_params) -> float:
        # Argument order must match the order in which the priors are declared.
        incidence = model.func(gen_time_mean, gen_time_sd, list(proc_params), seed).incidence
        # Only time steps that have an observation contribute to the score.
        return 0.0 - sum((incidence[t] * cdr_param - observed) ** 2 for t, observed in targets.items())

    return obj_func_ll

In [None]:
# Uniform priors over each scalar parameter's bounds (in params_df row order),
# followed by one vector prior for the non-mechanistic process values.
# The order must match the positional arguments of the objective function:
# gen-time mean, gen-time SD, CDR, log seed, process values.
priors = [esp.UniformPrior(name, (bounds['lower'], bounds['upper'])) for name, bounds in params_df.iterrows()]
priors.append(
    esp.UniformPrior(
        'Log non-mechanistic process values',
        (proc_req['lower'], proc_req['upper']),
        size=n_process_periods,
    )
)
renewal_model = RenewalModel(calib_kwargs['pop'], n_times, run_in, n_process_periods)
obj_func = get_obj_func(renewal_model)
n_draws = 100
sampling_seed = 42  # fixes the sampler RNG so Restart & Run All reproduces the posterior

with pm.Model() as pmm:
    variables = use_model(priors, obj_func)
    idata = pm.sample(step=[pm.DEMetropolisZ(variables)], draws=n_draws, random_seed=sampling_seed)

In [None]:
# Posterior summary table for display (ArviZ rounds these values by default).
az.summary(data=idata)

In [None]:
# Posterior means, one per summary row. The row order is assumed to follow the
# prior declaration order: gen-time mean, gen-time SD, CDR, log seed, process values.
# round_to='none' keeps full precision: az.summary rounds to 2 decimals by
# default, which would truncate e.g. a case detection proportion of 0.0634 to 0.06.
# kind='stats' skips the convergence diagnostics, which are not needed here.
mean_posterior_params = np.array(az.summary(idata, kind='stats', round_to='none')['mean'])
mpp = mean_posterior_params  # short alias used by later cells

In [None]:
# Simulate at the posterior means (gen-time mean and SD unpacked from mpp[:2],
# process values mpp[4:], seed mpp[3]) and scale by the detection proportion mpp[2].
incidence = renewal_model.func(*mpp[:2], mpp[4:], mpp[3]).incidence * mpp[2]

In [None]:
import random

burn_in = 10
n_samples = 10
spaghetti_seed = 1  # seeds the choice of plotted draws so re-runs match
idata_df = idata.to_dataframe()
# Drop burn-in by draw number. 'draw' is a column of this frame, not its index;
# the index is presumably a plain row counter across chains, so the former
# `idata_df.index >= burn_in` only trimmed the first rows, not the first draws.
idata_df = idata_df.loc[idata_df['draw'] >= burn_in]
sampled_draws = sorted(random.Random(spaghetti_seed).sample(range(burn_in, n_draws), n_samples))
# Keeps every chain's row for each sampled draw number.
sampled_df = idata_df[idata_df['draw'].isin(sampled_draws)]

# Re-simulate each sampled posterior row, one column per row, scaled to detected cases.
# NOTE(review): parameters are read by column position in to_dataframe() output:
# 2 gen-time mean, 3 gen-time SD, 4 CDR, 5 log seed, 6-17 process values.
# The process columns are assumed to be in text order ('[0]', '[10]', '[11]',
# '[1]', ..., '[9]'), hence the reordering into period order below. Confirm
# this against idata_df.columns; selecting by name would be more robust.
spaghetti = pd.DataFrame()
for row_label, row in sampled_df.iterrows():
    proc = [row.iloc[6], *row.iloc[9:18], *row.iloc[7:9]]
    trajectory = renewal_model.func(row.iloc[2], row.iloc[3], proc, row.iloc[5]).incidence * row.iloc[4]
    spaghetti[row_label] = trajectory

In [None]:
# Overlay the observed targets; they align with the model time index because
# they were shifted by run_in when loaded.
spaghetti = spaghetti.assign(targets=calib_kwargs['targets'])

In [None]:
# Sampled posterior trajectories (detected incidence) against the observed targets.
# The plotting backend is plotly, so title/labels are forwarded to plotly express.
spaghetti.plot(
    title='Sampled posterior trajectories vs observed cases',
    labels={'index': 'Model time step', 'value': 'Cases', 'variable': 'Series'},
)