This notebook performs a main model run.

# Complex Different Effects Model

**What's different about this model compared to v1?**

This version switches from a centred to a non-centred parameterisation (a classic probabilistic programming trick).
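Here is a minimal illustration of the idea. It uses plain NumPy, not the model's actual PyMC3 code, and `mu` and `sigma` are hypothetical stand-ins for an NPI's mean effect and its region-level scale. A centred parameterisation draws each region-level effect directly as $\alpha \sim \mathcal{N}(\mu, \sigma)$. The non-centred version instead draws standard-normal noise $z \sim \mathcal{N}(0, 1)$ and sets $\alpha = \mu + \sigma z$. Both give the same distribution for $\alpha$. The benefit is that the sampler then explores $z$, whose scale does not depend on $\sigma$, which typically avoids the "funnel" geometry that hierarchical scale parameters create.

```python
import numpy as np

rng = np.random.default_rng(42)
n_draws = 200_000
mu, sigma = 0.3, 0.1  # hypothetical mean effect and region-level scale

# Centred: draw alpha directly from N(mu, sigma)
alpha_centred = rng.normal(loc=mu, scale=sigma, size=n_draws)

# Non-centred: draw standard-normal noise, then shift and scale it
z = rng.normal(size=n_draws)
alpha_noncentred = mu + sigma * z

print(alpha_centred.mean(), alpha_noncentred.mean())  # both close to 0.3
print(alpha_centred.std(), alpha_noncentred.std())    # both close to 0.1
```

The two sets of samples are statistically indistinguishable. Only the quantity the sampler moves around in changes.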

In [None]:
from epimodel import EpidemiologicalParameters, DefaultModel, preprocess_data
from epimodel.pymc3_models.models import ComplexDifferentEffectsModel

import numpy as np
import pymc3 as pm

In [None]:
data = preprocess_data('../merged_data/double_entry_final.csv', last_day='2020-05-30', smoothing=1)
data.mask_reopenings(print_out=False)

In [None]:
ep = EpidemiologicalParameters()

In [None]:
bd = ep.get_model_build_dict()

In [None]:
bd

In [None]:
with ComplexDifferentEffectsModel(data) as model:
    model.build_model(**bd)

In [None]:
pm.model_to_graphviz(model)

# Note: I've increased `target_accept`, which might also improve sampling.

In [None]:
with model:
    model.trace = pm.sample(2000, tune=500, cores=4, chains=4, max_treedepth=12, target_accept=0.95)

# Is the inference stable?

In [None]:
import arviz as az

In [None]:
rhat = az.rhat(model.trace)
ess = az.ess(model.trace)

In [None]:
model.trace.varnames

In [None]:
def collate(stat):
    """Flatten a per-variable diagnostic (e.g. R-hat or ESS) into a single 1-D array."""
    stat_all = []   # flattened values of array-valued variables
    stat_nums = []  # values of scalar variables
    for var in ["CMReduction", "GI_mean", "GI_sd", "GrowthCasesNoise", "GrowthDeathsNoise", "CasesDelayMean", "CasesDelayDisp",
                "InitialSizeDeaths_log", "InitialSizeCases_log", "DeathsDelayMean", "DeathsDelayDisp", "HyperRVar", "PsiCases",
                "PsiDeaths", "InfectedDeaths", "InfectedCases", "ExpectedDeaths", "ExpectedCases", "AllCMAlphaNoise", "AllCMAlpha"]:
        if stat[var].size > 1:
            stat_all.append(stat[var].values.flatten())
        else:
            stat_nums.append(float(stat[var]))
    # concatenate the list of different-length arrays directly; wrapping it in
    # np.array() first builds a ragged object array, which recent NumPy rejects
    stat_all = np.concatenate(stat_all + [np.array(stat_nums)])
#     stat_all[stat_all > 100] = 1
    return stat_all

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.figure(figsize=(7, 3), dpi=300)
plt.subplot(121)
plt.hist(collate(rhat), bins=40, color='tab:purple')
# raw string so Python does not treat "\h" as an escape sequence
plt.title(r"$\hat{R}$", fontsize=12)

plt.subplot(122)
# relative ESS = ESS / total draws (4 chains x 2000 draws = 8000)
plt.hist(collate(ess)/8000, bins=40, color='tab:purple')
plt.xlim([0, 3.5])
plt.title("Relative ESS", fontsize=12)

$\hat{R}$ looks okay, but the ESS looks somewhat troubling, so let's inspect both in more detail.

In [None]:
rhat

# $\hat{R}$ issues

I suspected that the sampling would be unstable. It's much better than expected, but we have slightly high $\hat{R}$ values for the individual $\alpha_{i,c}$, as well as for the scales.

The most important thing is to avoid having too many divergences. There are no divergences here, which is good; I think we'd just need to run this for more samples.
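For intuition about what $\hat{R}$ measures, here is a minimal NumPy sketch of the classic split-$\hat{R}$ statistic for a single scalar parameter. It compares the variance between (half-)chains with the variance within them. This is an illustration, not ArviZ's implementation: recent ArviZ versions default to a rank-normalised variant, so `az.rhat` will give slightly different numbers. The chains here are synthetic, not samples from this model.

```python
import numpy as np

def split_rhat(chains: np.ndarray) -> float:
    """Classic split-R-hat for one scalar parameter.

    chains: array of shape (n_chains, n_draws).
    """
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split every chain in two, so drift within a chain also inflates R-hat
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    chain_vars = split.var(axis=1, ddof=1)
    between = n * chain_means.var(ddof=1)  # B: between-chain variance
    within = chain_vars.mean()             # W: mean within-chain variance
    var_plus = (n - 1) / n * within + between / n
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(0)
well_mixed = rng.normal(size=(4, 2000))
# Two of the four chains are stuck around a different value
stuck = well_mixed + np.array([[0.0], [0.0], [1.0], [1.0]])

print(split_rhat(well_mixed))  # very close to 1
print(split_rhat(stuck))       # clearly above 1
```

Values close to 1 mean the chains agree. The common rule of thumb treats anything noticeably above about 1.01 as a sign that more (or better) sampling is needed.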

In [None]:
ess

# ESS issues!

There is a similar issue with the effective sample size: we only have 23 *effective* samples for the School Closure NPI. This means that the sampler is having some difficulty exploring these parameters, and the trace has high autocorrelation.

The inference is still better than I originally expected, and I have a few thoughts on how to improve it.
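To see why high autocorrelation translates into few effective samples, here is a simplified single-chain ESS estimator: $\text{ESS} = n / (1 + 2\sum_{t \ge 1} \rho_t)$, summing autocorrelations until the first negative one. This is a sketch under stated simplifications. ArviZ's `az.ess` uses a multi-chain, rank-normalised estimator with a more careful truncation rule, so it will not match exactly. The AR(1) process is a hypothetical stand-in for a "sticky" MCMC trace.

```python
import numpy as np

def autocorr(x: np.ndarray) -> np.ndarray:
    """Sample autocorrelation at lags 0..n-1, computed via FFT."""
    n = len(x)
    x = x - x.mean()
    # zero-pad to length 2n so the FFT gives linear (not circular) correlation
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conjugate(f), n=2 * n)[:n]
    return acov / acov[0]

def ess_single_chain(x: np.ndarray) -> float:
    """ESS = n / tau, summing autocorrelations until the first negative lag."""
    rho = autocorr(x)
    tau = 1.0  # integrated autocorrelation time
    for t in range(1, len(x)):
        if rho[t] < 0:
            break
        tau += 2.0 * rho[t]
    return len(x) / tau

def ar1_chain(n: int, phi: float, rng: np.random.Generator) -> np.ndarray:
    """AR(1) process x_t = phi * x_{t-1} + noise."""
    x = np.empty(n)
    x[0] = rng.normal() / np.sqrt(1 - phi**2)  # start from the stationary distribution
    noise = rng.normal(size=n)
    for t in range(1, n):
        x[t] = phi * x[t - 1] + noise[t]
    return x

rng = np.random.default_rng(1)
n = 20_000
iid = rng.normal(size=n)
sticky = ar1_chain(n, phi=0.95, rng=rng)

print(ess_single_chain(iid) / n)     # relative ESS near 1
print(ess_single_chain(sticky) / n)  # far below 1 (AR(1) theory: (1 - phi) / (1 + phi), about 0.026)
```

Both chains have the same length, but the strongly autocorrelated one carries only a few percent of the information. This is the same pattern as the low relative ESS in the histogram above.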

# Are the traces bad?

In [None]:
az.plot_trace(model.trace, var_names=['CMAlphaScales'])

Indeed, inspecting the traces above, we see very high autocorrelation (look at the right-hand plots). The density estimates for the different chains also look quite different.

Anyone who saw these traces would tell you that they are probably garbage.

I have some ideas about how to fix this, though.

# Preliminary results analysis

In [None]:
model.plot_effect()

The results are similar to what we expected; the differences are not large.

## What alpha noise do we learn?

In [None]:
import matplotlib.pyplot as plt

n_cms = len(data.CMs)               # number of NPIs (9 in this dataset)
scales = model.trace.CMAlphaScales  # posterior samples, shape (n_samples, n_cms)

plt.figure(figsize=(4, 3), dpi=300)

# posterior medians
plt.scatter(np.percentile(scales, 50, axis=0), -np.arange(n_cms))

for i in range(n_cms):
    # 95% credible interval (faint) and 50% interval (darker)
    plt.plot(np.percentile(scales[:, i], [2.5, 97.5]), [-i, -i], color='tab:blue', alpha=0.25)
    plt.plot(np.percentile(scales[:, i], [25, 75]), [-i, -i], color='tab:blue', alpha=0.5)

# vertical reference line at sigma = 0.1
plt.plot([0.1, 0.1], [-n_cms, 5], 'k--')

plt.ylim([-(n_cms - 0.5), 0.5])
plt.yticks(-np.arange(n_cms), data.CMs)
plt.xlabel(r'$\sigma_{\alpha, i}$')
plt.title('NPI Region Variability')

# What does this actually look like, uncertainty-wise?
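Each plot below converts samples of a region-level effect $\alpha_{i,c}$ into a percentage reduction in $R_t$, $100\,(1 - e^{-\alpha_{i,c}})$, and compares them with the aggregate effect derived from `CMReduction`. Here is a small sketch of that conversion. The helper name `percent_reduction` is hypothetical and not part of `epimodel`.

```python
import numpy as np

def percent_reduction(alpha):
    """Percentage reduction in R_t when R_t is multiplied by exp(-alpha)."""
    return 100 * (1 - np.exp(-np.asarray(alpha, dtype=float)))

print(percent_reduction(0.0))                         # 0.0 (no effect)
print(round(float(percent_reduction(np.log(2))), 6))  # 50.0 (R_t halved)
print(percent_reduction([0.1, 0.5, 1.0]).round(1))    # larger alpha, larger reduction
```

Because the transform is monotonic, the plotted percentiles of the transformed samples match the transformed percentiles of $\alpha$, up to interpolation between neighbouring samples.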

In [None]:
n_cms = len(data.CMs)  # number of NPIs (9 in this dataset)
n_rs = len(data.Rs)    # number of regions

for i in range(n_cms):
    plt.figure(figsize=(6.5, 3), dpi=300)
    for r in range(n_rs):
        # per-region percentage reduction in R_t implied by alpha_{i,c}
        perred = 100 * (1 - np.exp(-model.trace.AllCMAlpha[:, r, i]))
        plt.plot([r, r], np.percentile(perred, [2.5, 97.5]), color='k', alpha=0.25)  # 95% interval
        plt.plot([r, r], np.percentile(perred, [25, 75]), color='k', alpha=0.5)      # 50% interval
        plt.scatter(r, np.median(perred), color='k', marker='_')                     # median

    plt.ylabel('Reduction in $R_t$ (%)')
    # dashed separator between the per-region estimates and the aggregate column
    y_min, y_max = plt.ylim()
    plt.plot([n_rs - 0.5, n_rs - 0.5], [y_min - 2, y_max + 2], '--', color='tab:red')
    plt.ylim([y_min, y_max])

    # aggregate (cross-region) effect, plotted in the final column
    overall_red = 100 * (1 - model.trace.CMReduction[:, i])
    plt.scatter(n_rs, np.median(overall_red), color='tab:purple')
    plt.plot([n_rs, n_rs], np.percentile(overall_red, [2.5, 97.5]), color='tab:purple', alpha=0.25)
    plt.plot([n_rs, n_rs], np.percentile(overall_red, [25, 75]), color='tab:purple', alpha=0.5)
    plt.title(data.CMs[i])
    plt.xticks(np.arange(n_rs + 1), [*data.Rs, 'Agg'], fontsize=6)

In [None]:
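import os
import pickle

os.makedirs('traces', exist_ok=True)  # make sure the output directory exists
# use a context manager so the file is flushed and closed after writing
with open('traces/complexdiffeffv2.pkl', 'wb') as f:
    pickle.dump(model.trace, f)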