In [None]:
# make sure you are using the pymc_env environment
# train final hetGPy model
import pandas as pd
import numpy as np
from hetgpy.hetGP import hetGP
import sys
sys.path.append('../../')
from fig1_calibration import create_sim as cs
from scipy.stats import qmc

# Parameter bounds; raw parameter values are rescaled onto [0, 1] with them below.
bounds = cs.define_pars(which='bounds',use_safegraph=True)
bounds['tn'][0] = 5.0  # override the lower bound of 'tn' for this round
keys = list(bounds.keys())
l_bounds = [bounds[k][0] for k in keys]
u_bounds = [bounds[k][1] for k in keys]

# get model outputs
sims = pd.read_csv('../data/sims_combined-wave004.csv')

# Day index within each simulation. Rows are assumed to be in time order within
# each rand_seed. cumcount numbers rows per seed, so the result does not depend
# on seeds being stored as contiguous, equally sized blocks (which the previous
# concatenation of one time vector per seed silently required).
s0   = sims.loc[sims['rand_seed']==sims['rand_seed'].min()]
tmax = s0.shape[0]
rows_per_seed = sims.groupby('rand_seed').size()
assert (rows_per_seed == tmax).all(), "every rand_seed must have the same number of rows"
sims['t'] = sims.groupby('rand_seed').cumcount()

tkeep = np.arange(0,tmax,7) # weekly output -- also will include infections
sims = sims.loc[sims['t'].isin(tkeep)]

output_keys = ['death']
models = {}
# Emulator inputs: parameters mapped from [l_bounds, u_bounds] to [0, 1]
# (qmc.scale with reverse=True), followed by time rescaled to [0, 1].
X = qmc.scale(
    sims[keys].values,
    l_bounds = l_bounds,
    u_bounds = u_bounds,
    reverse=True
    )
t_span = sims.t.max() - sims.t.min()
# Guard against a single kept time point, which would otherwise divide by zero.
t = ((sims.t - sims.t.min()) / (t_span if t_span > 0 else 1)).values.reshape(-1,1)
X = np.hstack([X,t])
X.shape, sims.shape

In [None]:
# Fit one heteroskedastic GP emulator per output. Inputs are the columns of X
# (unit-scaled parameters plus normalised time); lengthscale search box is [0.1, 10].
n_inputs = X.shape[1]
m_and_sds = {}
for output_name in output_keys:
    y = sims[output_name].values
    # Record mean / std of the raw output (not applied to the fit below).
    m_and_sds[output_name] = y.mean(), y.std()
    gp = hetGP()
    gp.mleHetGP(
        X=X,
        Z=y,
        covtype="Matern5_2",
        lower=[0.1] * n_inputs,
        upper=[10] * n_inputs,
        maxit=1000,
        settings={'checkHom': True},
    )
    models[output_name] = gp
models

In [None]:
# Diagnostic plot of the fitted hetGP emulator for the 'death' output.
death_emulator = models['death']
death_emulator.plot()

In [None]:
import pickle
# Persist the dict of fitted emulators (keyed by output name) for later reuse.
models_path = '../models/hetGPy-last-round-death.pkl'
with open(models_path, 'wb') as fh:
    pickle.dump(models, fh)

In [None]:
# Observed targets: the "<output>_data" columns taken from the rows of a single
# seed (the smallest rand_seed). NOTE(review): presumably these columns are
# identical across seeds — TODO confirm. `.copy()` gives an independent frame,
# so the column assignment below does not raise SettingWithCopyWarning.
obs_cols = [f"{col}_data" for col in output_keys]
is_first_seed = sims['rand_seed'] == sims['rand_seed'].min()
observed = sims.loc[is_first_seed, obs_cols].copy()
if 'infectious' in output_keys:
    # Positions of missing observations; they are zero-filled here and sim()
    # zeroes its 'infectious' draws at the same positions.
    nan_idxs = np.isnan(observed['infectious_data'].values).nonzero()[0]
    observed['infectious_data'] = observed['infectious_data'].fillna(0)
else:
    nan_idxs = []
observed

In [None]:
import matplotlib.pyplot as plt
rand = np.random.default_rng(42)
NI  = pd.read_csv('../hm_waves/NI_pars_wave003.csv')
# Map the parameter sets from NI_pars_wave003.csv onto the unit hypercube used
# to train the emulators (qmc.scale with reverse=True: bounds -> [0, 1]).
lower_b = [bounds[k][0] for k in keys]
upper_b = [bounds[k][1] for k in keys]
unit_pars = qmc.scale(NI[keys].values, l_bounds=lower_b, u_bounds=upper_b, reverse=True)
df_NI_scaled = pd.DataFrame(unit_pars, columns=keys)
# Weekly output times rescaled to [0, 1], as a column vector consumed by sim().
t_lo, t_hi = tkeep.min(), tkeep.max()
tvec = ((tkeep - t_lo) / (t_hi - t_lo)).reshape(-1, 1)
def sim(rng,a,b,c,d,size=None):
    """Draw one stochastic realisation of every emulated output.

    Evaluates each fitted surrogate in the global ``models`` dict at the
    parameter point (a, b, c, d) for every normalised time in the global
    ``tvec`` column vector. It then samples from the normal distribution with
    the ``mean`` and ``sd2`` returned by ``predict``. The signature follows the
    ``fn(rng, *params, size)`` convention used by ``pm.Simulator``.

    Parameters
    ----------
    rng : numpy.random.Generator or None
        Generator to draw from. pm.Simulator passes its own generator here;
        when None (manual calls), the module-level ``rand`` generator is used.
    a, b, c, d : float
        Unit-scaled values of beta, bc_lf, tn and bc_wc1. NOTE(review): this
        order presumably must match ``keys`` (the emulator training columns)
        — TODO confirm.
    size : ignored
        Accepted for the pm.Simulator interface; one draw per time point is
        always returned.

    Returns
    -------
    pandas.DataFrame
        One column per key of ``models`` and one row per entry of ``tvec``.
        Draws for 'infectious' are zeroed at ``nan_idxs``.
    """
    # Honour the generator supplied by the caller. Drawing from the shared
    # global instead ignored pm.Simulator's rng, so SMC runs could not be made
    # reproducible via random_seed, and parallel chains may each start from a
    # copy of the same global state.
    gen = rand if rng is None else rng

    # One emulator input row per time point: [beta, bc_lf, tn, bc_wc1, t].
    x = np.array([a, b, c, d]).reshape(-1)
    X = np.hstack([np.tile(x, (len(tvec), 1)), tvec])

    out = {}
    for key, model in models.items():
        p = model.predict(X)
        draw = gen.normal(loc=p['mean'], scale=np.sqrt(p['sd2']))
        if key == 'infectious':
            # Match the zero-filled missing observations.
            draw[nan_idxs] = 0
        out[key] = draw
    return pd.DataFrame(out)


# samples from simulation: emulator draws (blue) at parameter sets sampled
# from df_NI_scaled, overlaid with the observed series (black).
n_sample = 50
# Fixed random_state so the same parameter sets are drawn on every re-run;
# cap n at the available rows so .sample() cannot raise.
n_draw = min(n_sample, len(df_NI_scaled))
samples = df_NI_scaled.sample(n=n_draw, random_state=42)[keys].values

fig, ax = plt.subplots(figsize=(6,4))
for p in samples:
    s = sim(rng=None, a=p[0], b=p[1], c=p[2], d=p[3], size=None)
    for col in s.columns:
        s[col].plot(ax=ax, color='blue')

# Draw the observed series once, after the simulations, so it stays visible
# on top (it was previously re-drawn inside every loop iteration).
observed_plot = observed.reset_index(drop=True)
for col in models:
    observed_plot[f"{col}_data"].plot(ax=ax, color='black', alpha=0.6)
ax.set(xlabel='week (row of weekly output)', ylabel='output',
       title='Emulator draws vs observed')
plt.show()

In [None]:
import pymc as pm
rand = np.random.default_rng(42)

def setup_inputs(key, kw):
    """Return a copy of ``kw`` extended with TruncatedNormal prior arguments.

    The prior for parameter ``key`` uses the mean (``mu``) and sample standard
    deviation (``sigma``) of that column of the global ``df_NI_scaled``, and
    is truncated to the unit interval [0, 1]. ``kw`` itself is not modified.
    """
    column = df_NI_scaled[key]
    return {**kw, 'mu': column.mean(), 'sigma': column.std(), 'lower': 0.0, 'upper': 1.0}

# ABC-SMC calibration: TruncatedNormal priors on the unit-scaled parameters,
# compared with `observed` through the emulator-based simulator `sim`.
SMC_SEED = 42  # fixes the SMC sampler's randomness so re-runs reproduce
with pm.Model() as model_abc:
    kw = {}
    a = pm.TruncatedNormal("beta",**setup_inputs('beta',kw))
    b = pm.TruncatedNormal("bc_lf",**setup_inputs('bc_lf',kw))
    c = pm.TruncatedNormal("tn",**setup_inputs('tn',kw))
    d = pm.TruncatedNormal("bc_wc1",**setup_inputs('bc_wc1',kw))

    # params are passed positionally to sim(rng, a, b, c, d, size), so this
    # order must match sim's parameters. epsilon is pm.Simulator's scaling
    # parameter for the distance function.
    simulator = pm.Simulator("sim3", fn=sim, params=(a,b,c,d),epsilon=5,observed=observed)

    idata_abc = pm.sample_smc(draws=2000,cores=5,chains=5,random_seed=SMC_SEED)

In [None]:
# Save the calibrated posterior three ways: netCDF, a flat CSV of the
# parameter columns in `keys`, and a pickle of the full InferenceData object.
posterior = idata_abc.posterior
posterior.to_netcdf('../posterior/calibrated-death.nc')
posterior_df = posterior.to_dataframe().reset_index()
posterior_df[keys].to_csv('../posterior/calibrated-death.csv', index=False)
with open('../posterior/calibrated-death-full.pkl', 'wb') as fh:
    pickle.dump(idata_abc, fh)