# Countermeasures model v2

Changes:
* Fixed the convolution by replacing it with a tested library version
* Cleaned up the import code, removed the obsolete mask, added text cells
* Improved plotting and write-outs
* Renamed from "basic" to v1 (this is hardly a simple model)
* Moved most of the code into the library (WIP)
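
The delay convolution spreads each countermeasure's effect over the days after it is introduced. Below is a rough numpy sketch of that idea, not the `cm_effect` implementation. It assumes the delay follows a Poisson distribution truncated to a fixed number of days and renormalized; `delay_kernel` and `delay_features` are hypothetical helper names.

```python
import numpy as np

def delay_kernel(mean, length):
    """Hypothetical delay distribution: Poisson pmf with the given mean,
    truncated to `length` days and renormalized to sum to 1."""
    p = np.empty(length)
    p[0] = np.exp(-mean)
    for k in range(1, length):
        # Poisson recurrence: p(k) = p(k-1) * mean / k
        p[k] = p[k - 1] * mean / k
    return p / p.sum()

def delay_features(features, kernel):
    """Causally convolve each feature row (n_features x n_days) with the kernel,
    so a measure switched on at day t only reaches full effect gradually."""
    features = np.asarray(features, dtype=float)
    out = np.zeros_like(features)
    for i, row in enumerate(features):
        # 'full' convolution, truncated to the original length: out[t] = sum_k kernel[k] * row[t - k]
        out[i] = np.convolve(row, kernel)[: row.shape[0]]
    return out

# Usage: one measure switched on from day 5 of a 30-day window.
kernel = delay_kernel(7.0, 16)
feats = np.zeros((1, 30))
feats[0, 5:] = 1.0
delayed = delay_features(feats, kernel)
```

The delayed feature stays at 0 before day 5, then ramps up monotonically and reaches 1.0 once the whole kernel lies inside the "on" period.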

Model v2:
* Added a per-country unreliability parameter, used as a multiplier on the noise scale of both the growth rate and the measurements.
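
To illustrate what a per-country scale multiplier does, here is a hypothetical numpy sketch of a generative process, not the `CMModelV2` code. The log-space random walk, the helper name `simulate_region` and the noise levels `growth_sd` and `obs_sd` are all assumptions chosen for illustration.

```python
import numpy as np

def simulate_region(rng, n_days, base_growth, unreliability,
                    growth_sd=0.05, obs_sd=0.1, initial_size=10.0):
    """Hypothetical sketch: latent and observed log case counts for one region.
    `unreliability` multiplies both the growth-rate noise and the measurement
    noise, so less reliable countries get noisier trajectories and observations."""
    # Daily log growth fluctuates around the base rate; noise scaled per country.
    log_growth = np.log(base_growth) + rng.normal(0.0, growth_sd * unreliability, n_days)
    # Latent log case counts accumulate the daily growth.
    log_true = np.log(initial_size) + np.cumsum(log_growth)
    # Observations add measurement noise, scaled by the same multiplier.
    log_obs = log_true + rng.normal(0.0, obs_sd * unreliability, n_days)
    return log_true, log_obs

rng = np.random.default_rng(0)
true1, obs1 = simulate_region(rng, 20000, 1.2, unreliability=1.0)
true3, obs3 = simulate_region(rng, 20000, 1.2, unreliability=3.0)
```

A country with multiplier 3 shows roughly three times the spread in both the daily growth and the measurement residuals.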

## Import & initialization

In [4]:
%load_ext autoreload
%autoreload 2

import logging
import numpy as np
import pandas as pd
import pymc3 as pm
import theano.tensor as T
import matplotlib.pyplot as plt

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from epimodel.pymc3_models import cm_effect



## Loading data and parameter settings

In [16]:
Regions = ['AT', 'BA', 'BE', 'BG', 'CH', 'CZ', 'DE', 'DK', 'EE', 'ES', 'FI', 'FR', 'GB', 'GR', 'HR', 'HU', 'IE', 'IS', 'IT', 'LT', 'NL', 'PL', 'RO', 'RS', 'RU', 'SE', 'SI', 'SK']
Features = ['Gatherings limited to', 'Business suspended',
       'Schools and universities closed', 'General curfew',
       'Healthcare specialisation', 'Minor distancing and hygiene measures',
       'Phone line', 'Asymptomatic contact isolation']
# 'Mask wearing' is omitted: its data contains NaNs
data = cm_effect.Loader('2020-02-29', '2020-04-04', Regions, Features)

[2020-04-10 11:27:34,247] INFO(epimodel.regions): Name index has 7 potential conflicts: ['american samoa', 'georgia', 'guam', 'northern mariana islands', 'puerto rico', 'united states minor outlying islands', 'virgin islands (u.s.)']


## Create, run and plot stability of model V2

In [None]:
with cm_effect.CMModelV2(data, delay_mean=7.0) as model2:
    model2.build()
model2.run(1000)
_ = model2.plot_traces()

CM delay: mean 6.959979582552529, len 16, cut at 10


[2020-04-10 11:27:40,557] INFO(pymc3): Auto-assigning NUTS sampler...
[2020-04-10 11:27:40,558] INFO(pymc3): Initializing NUTS using adapt_diag...


CMReduction_log__            11.07
BaseGrowthRate_log__         -1.75
RegionGrowthRate_log__        7.98
RegionScaleMult_log__       -49.05
RealGrowth_log__           1394.72
InitialSize_log__           -90.20
Observed_missing              0.00
Observed                 -34453.14
Name: Log-probability of test_point, dtype: float64


[2020-04-10 11:27:47,147] INFO(pymc3): Multiprocess sampling (2 chains in 2 jobs)
[2020-04-10 11:27:47,147] INFO(pymc3): NUTS: [Observed_missing, InitialSize, RealGrowth, RegionScaleMult, RegionGrowthRate, BaseGrowthRate, CMReduction]
Sampling 2 chains, 0 divergences:  23%|██▎       | 690/3000 [05:14<40:39,  1.06s/draws]

## Create, run and plot stability of model V2g (Gaussian prior)

In [None]:
with cm_effect.CMModelV2g(data, delay_mean=6.0) as model2g:
    model2g.build()
model2g.run(1000)
_ = model2g.plot_traces()

## Plot inferred countermeasure effect

Effects are multiplicative. For example, for a countermeasure that strengthens another, the inferred strength is the additional multiplier on top of the weaker measure.

The countermeasure strength is the multiplicative effect when the feature value is 1.0.
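
A minimal sketch of how such multiplicative effects combine, assuming each countermeasure's multiplier is raised to the power of its feature value and the results are multiplied together (and, as an assumption here, applied to the growth rate). `growth_multiplier` is a hypothetical helper, not part of `cm_effect`.

```python
import numpy as np

def growth_multiplier(reductions, features):
    """Hypothetical sketch: combined effect of active countermeasures.
    reductions[i] is the inferred multiplier of measure i at feature value 1.0;
    features[i] in [0, 1] says how strongly it is active.
    Combined effect: prod_i reductions[i] ** features[i]."""
    reductions = np.asarray(reductions, dtype=float)
    features = np.asarray(features, dtype=float)
    return float(np.prod(reductions ** features))

# A basic measure (0.8) plus a stricter variant that adds a further 0.9:
both = growth_multiplier([0.8, 0.9], [1.0, 1.0])    # 0.8 * 0.9
basic_only = growth_multiplier([0.8, 0.9], [1.0, 0.0])
half_basic = growth_multiplier([0.8], [0.5])        # sqrt(0.8)
```

Under this assumption, the stricter variant's inferred value (0.9 here) is only the extra reduction beyond the basic measure, not its total effect.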

In [8]:
_ = model2.plot_CMReduction()
model2.print_CMReduction()

AssertionError: 

In [None]:
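import datetime

import plotly
import plotly.express as px
from plotly import graph_objects as go

def plot_line_CIs(fig, x, ys, name, color, quantiles=(0.05, 0.25), opacities=(0.1, 0.15)):
    """Add the mean of `ys` (samples x time) as a line, plus shaded bands
    between each quantile q and 1-q (by default 5-95% and 25-75%)."""
    x = list(x)
    # Mean trajectory across posterior samples
    fig.add_trace(go.Scatter(
        x=x,
        y=list(ys.mean(axis=0)),
        name=name, legendgroup=name, line_color=color
    ))
    for q, o in zip(quantiles, opacities):
        ylo = list(np.quantile(ys, q, axis=0))
        yhi = list(np.quantile(ys, 1.0 - q, axis=0))
        # Closed polygon: lower bound left to right, upper bound right to left
        fig.add_trace(go.Scatter(
            x=x + x[::-1],
            y=ylo + yhi[::-1],
            fill='toself',
            fillcolor=color,
            opacity=o,
            line_color='rgba(255,255,255,0)',
            showlegend=False,
            name=name, legendgroup=name,
        ))

# NOTE: `Cs` (region codes), `Ds` (dates), `trace` (posterior samples with a
# 'DailyGrowth' variable indexed as samples x regions x days) and `CMDelayCut`
# are not defined in this notebook; set them from `data` and `model2` first.
fig = go.FigureWidget()
for i, c in enumerate(Cs):
    d = trace['DailyGrowth'][:, i, :]
    color = (px.colors.qualitative.Dark24 * 10)[i]
    plot_line_CIs(fig, Ds[CMDelayCut:], d, c, color)
# Write a timestamped standalone HTML file (plotly.js loaded from the CDN)
datestr = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
plotly.io.write_html(fig, f'{datestr}_growth_estimated_v2.html', include_plotlyjs='cdn')
fig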