In [None]:
from jax import jit, random
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from IPython.display import Markdown
from plotly.express.colors import qualitative as qual_colours
from pathlib import Path
import math
import plotly.express as px
import plotly.graph_objects as go
from matplotlib import pyplot as plt

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti
from emu_renewal.outputs import plot_uncertainty_patches, PANEL_SUBTITLES, plot_3d_spaghetti, plot_post_prior_comparison
from emu_renewal.calibration import StandardCalib
from emu_renewal.utils import get_adjust_idata_index, adjust_summary_cols

from plotting import plot_main
from utils import load_target_data, load_mobility_data, load_vaccination_data, load_variant_prevalence_data

## Set up country and load data

In [None]:
# Country used to select the target, mobility, vaccination and variant data loaded below
country = "Philippines"

In [None]:
# Load the target case series plus the mobility, vaccination and variant data used in the figures
target_data, mobility_data, vaccination_data, variant_data = (
    load_target_data(country),
    load_mobility_data(country),
    load_vaccination_data(country),
    load_variant_prevalence_data(country),
)

## Specify renewal model and parameters

In [None]:
# Specify fixed (non-calibrated) settings and build the calibration data
proc_update_freq = 21  # passed to RenewalModel as the process update frequency; presumably days -- confirm
init_time = 50  # length of the initialisation period, in days (see init_start below)
data = target_data
pop = 116e6  # total population, hard-coded for the Philippines; must be updated if `country` changes
analysis_start = datetime(2021, 5, 1)  # start of the analysis/calibration window: 1 May 2021
analysis_end = datetime(2022, 4, 30)  # end of the analysis/calibration window: 30 April 2022
# Initialisation period: the `init_time` days immediately before analysis_start
init_start = analysis_start - timedelta(days=init_time)
init_end = analysis_start - timedelta(days=1)
# Calibration targets inside the analysis window (.loc label slicing includes both endpoints)
select_data = data.loc[analysis_start:analysis_end]
# Daily series for the initialisation period: upsample to daily frequency, fill the gaps by
# linear interpolation, then divide by 7 (presumably converting weekly totals to daily counts)
init_data = data.resample("D").asfreq().interpolate().loc[init_start:init_end] / 7.0

In [None]:
# Build the renewal model over the analysis window, with CosineMultiCurve fitting the process updates
proc_fitter = CosineMultiCurve()
renew_model = RenewalModel(
    pop,
    analysis_start,
    analysis_end,
    proc_update_freq,
    proc_fitter,
    GammaDens(),  # first density argument -- presumably the generation-interval density; confirm
    init_time,
    init_data,
    GammaDens(),  # second density argument -- presumably the reporting-delay density; confirm
)

In [None]:
# Priors for the calibrated parameters (insertion order matches the original dict literal)
priors = dict(
    gen_mean=dist.TruncatedNormal(7.3, 0.5, low=1.0),  # truncated below at 1.0
    gen_sd=dist.TruncatedNormal(3.8, 0.5, low=1.0),  # truncated below at 1.0
    cdr=dist.Beta(16, 40),  # prior mean 16/56 ~ 0.29
    rt_init=dist.Normal(0.0, 0.25),  # centred on 0 -- presumably on a log scale; confirm
    report_mean=dist.TruncatedNormal(8, 0.5, low=1.0),  # truncated below at 1.0
    report_sd=dist.TruncatedNormal(3, 0.5, low=1.0),  # truncated below at 1.0
    prop_immune=dist.Beta(32, 40),  # prior mean 32/72 ~ 0.44
)

## Model calibration

In [None]:
# Calibrate to weekly case counts with NUTS: 4 chains x 4000 draws after 500 warm-up steps each
calib = StandardCalib(renew_model, priors, select_data, indicator="weekly_sum")
init_strategy = infer.init_to_uniform(radius=0.5)
kernel = infer.NUTS(calib.calibration, dense_mass=True, init_strategy=init_strategy)
mcmc = infer.MCMC(kernel, num_warmup=500, num_samples=4000, num_chains=4)
mcmc.run(random.PRNGKey(123))  # fixed seed so the chains are reproducible

## Wrangle model outputs

In [None]:
# Convert the posterior draws (grouped by chain) to InferenceData, then take a random
# subsample of 800 draws for forward simulation. The subsample is seeded so that
# Restart & Run All reproduces the same draws, and therefore the same figures and tables.
idata = az.from_dict(mcmc.get_samples(group_by_chain=True))
idata_sampled = az.extract(idata, num_samples=800, rng=123)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, cdr, rt_init, report_mean, report_sd, prop_immune):
    """Evaluate the renewal model for a single parameter set.

    Thin wrapper around ``renew_model.renewal_func`` so that it can be passed to ``jit``;
    all arguments are forwarded positionally in the order received.
    """
    model_args = (gen_mean, gen_sd, proc, cdr, rt_init, report_mean, report_sd, prop_immune)
    return renew_model.renewal_func(*model_args)

# JIT-compile the wrapper, evaluate it over the sampled parameter sets via
# get_spaghetti_from_params, then summarise as 5%/50%/95% quantiles per output panel.
# The first panel label is replaced by the calibration indicator name.
full_wrap = jit(get_full_result)
panel_subtitles = ["weekly_sum"] + PANEL_SUBTITLES[1:]
spaghetti = get_spaghetti_from_params(
    renew_model,
    sample_params,
    full_wrap,
    outputs=panel_subtitles,
)
quantiles_df = get_quant_df_from_spaghetti(
    renew_model,
    spaghetti,
    quantiles=[0.05, 0.5, 0.95],
    outputs=panel_subtitles,
)

## Visualise model outputs and comparison against data

In [None]:
# Clip the mobility and vaccination series to the analysis window (inclusive of both ends)
analysis_window = slice(analysis_start, analysis_end)
mobility_data = mobility_data.loc[analysis_window]
vaccination_data = vaccination_data.loc[analysis_window]

In [None]:
fig = plot_main(quantiles_df, select_data, mobility_data, vaccination_data).update_layout(showlegend=False)

# Mark the week each variant exceeds 50% prevalence with a dashed vertical line on each of the
# first N_VLINE_COLS columns of the top subplot row. Keys are `variant` values in variant_data.
VARIANT_VLINE_LABELS = {"Delta": "Delta >50%", "Omicron_BA1_2": "Omicron >50%"}
N_VLINE_COLS = 2  # lines go on row 1, columns 1..N_VLINE_COLS

# Look each variant date up once (it does not depend on the subplot column)
variant_x_ms = {}
for variant, label in VARIANT_VLINE_LABELS.items():
    variant_weeks = pd.to_datetime(variant_data.loc[variant_data["variant"] == variant, "week"])
    if variant_weeks.empty:
        raise ValueError(f"variant_data has no row for variant {variant!r}")
    # x is passed as epoch milliseconds, presumably to sidestep Plotly's add_vline annotation
    # handling of date values. pandas' Timestamp.timestamp() treats a naive timestamp as UTC
    # (matching Plotly's epoch-ms axis), whereas datetime.timestamp() applies the machine's
    # local time zone and would shift the line by the UTC offset.
    variant_x_ms[variant] = variant_weeks.min().timestamp() * 1000

for col in range(1, N_VLINE_COLS + 1):
    for variant, label in VARIANT_VLINE_LABELS.items():
        fig.add_vline(
            x=variant_x_ms[variant],
            annotation_text=label,
            annotation_position="bottom right",
            row=1,
            col=col,
            line_dash="dash",
        )

# Axis titles keyed by Plotly layout axis name (axes as created by plot_main)
AXIS_TITLES = {
    "xaxis5": "Date",
    "xaxis6": "Date",
    "yaxis": "Weekly reported cases",
    "yaxis2": "",
    "yaxis3": "Total persons susceptible",
    "yaxis4": "",
    "yaxis5": "% change from baseline",
    "yaxis6": "% total population",
}
for axis_key, axis_title in AXIS_TITLES.items():
    fig["layout"][axis_key]["title"] = axis_title

fig.show()

## Exploring model outputs and calculating attack rates

In [None]:
# Quick look at the quantile outputs around the turn of 2021/22.
# Left as the cell's last expression so Jupyter renders the frame as an HTML table, not plain text.
quantiles_df.loc['2021-12-20':'2022-1-10']

In [None]:
# Largest 95th-percentile R between 1 Dec 2021 and 20 Feb 2022
r_upper_window = quantiles_df.loc['2021-12-1':'2022-2-20']['R'][0.95]
r_upper_window.max()

In [None]:
# Attack rate over time, computed as AR = 1 - S/pop from the susceptibles quantiles.
# NOTE(review): this is the share of the population no longer susceptible, which presumably
# also counts anyone the model starts as immune via `prop_immune` -- confirm the intended definition.
# AR is a decreasing function of S, so the q-th AR quantile comes from the (1 - q)-th
# susceptibles quantile: the 5th AR percentile uses the 95th susceptibles percentile, and
# vice versa. This keeps AR_05 <= AR_5 <= AR_95.
# .copy() gives an independent frame, so adding columns never touches quantiles_df
# (and avoids pandas' SettingWithCopyWarning).
suscepts = quantiles_df['susceptibles'].copy()
suscepts['AR_05'] = 1 - suscepts[0.95] / pop
suscepts['AR_5'] = 1 - suscepts[0.50] / pop
suscepts['AR_95'] = 1 - suscepts[0.05] / pop

In [None]:
# Susceptibles quantiles and attack rates for the last five dates of the analysis window
suscepts.iloc[-5:]

## Calibration results

In [None]:
# Posterior summary table for all sampled variables, with 95% highest-density intervals
az.summary(data=idata, hdi_prob=0.95)

In [None]:
# Prior vs posterior comparison for every prior except rt_init (trailing ';' hides the return value)
compared_params = [name for name in priors if name != "rt_init"]
plot_post_prior_comparison(idata, compared_params, priors);
# plt.savefig('phil_posterior.png')

In [None]:
# Trace and marginal posterior plots for every sampled variable
trace_axes = az.plot_trace(idata)
plt.tight_layout()
# plt.savefig('philippines_trace.png')

## Model descriptions and table for supplement

In [None]:
model_description_md = renew_model.get_description()
Markdown(model_description_md)

In [None]:
calib_description_md = calib.get_description()
Markdown(calib_description_md)

## Export model outputs, target data and comparison data to Excel for submission

In [None]:
# Write the model quantiles plus the data they are compared against to a single workbook.
# The filename is derived from `country` so the notebook can be re-run for another setting
# without overwriting or mislabelling this file.
with pd.ExcelWriter(f"{country}_results.xlsx") as writer:
    quantiles_df.to_excel(writer, sheet_name='Model_outputs')
    select_data.to_excel(writer, sheet_name='WHO_weekly_cases')
    mobility_data.to_excel(writer, sheet_name='Mobility_data')  # already clipped to the analysis window
    vaccination_data.to_excel(writer, sheet_name='Vaccination_data')  # already clipped to the analysis window