In [4]:
import numpy as np
import pymc3 as pm

import pickle
from datetime import datetime

import seaborn as sns
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt

from epimodel.pymc3_models import cm_effect
from epimodel.pymc3_models.cm_effect.datapreprocessor import DataPreprocessor

sns.set_style("ticks")
%matplotlib inline

In [7]:
dp = DataPreprocessor(drop_HS=True)
data = dp.preprocess_data("./double-entry-data/double_entry_alt_masks.csv", last_day="2020-05-30", schools_unis="single_features")

INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Travel Screen/Quarantine
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Travel Bans
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Public Transport Limited
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Internal Movement Limited
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Public Information Campaigns
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Symptomatic Testing
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Performing Smoothing
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Albania
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Georgia
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Iceland
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Latvia
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing New Zealand
IN

In [8]:
folds = [['FR', 'GR', 'NL', 'BA', 'LV'],
         ['SE', 'DE', 'LT', 'MY', 'BG'],
         ['FI', 'DK', 'CZ', 'RS', 'BE'],
         ['NO', 'SK', 'IL', 'CH', 'ES'],
         ['ZA', 'MX', 'IT', 'IE', 'GE'],
         ['RO', 'PL', 'MA', 'HU', 'SI'],
         ['NZ', 'SG', 'PT', 'HR', 'EE']]

In [9]:
class ResultsObject():
    def __init__(self, indxs, trace):
        self.CMReduction = trace.CMReduction
        self.RegionLogR = trace.RegionLogR[:, indxs]
        self.InfectedCases = trace.InfectedCases[:, indxs, :]
        self.InfectedDeaths = trace.InfectedDeaths[:, indxs, :]
        self.ExpectedCases = trace.ExpectedCases[:, indxs, :]
        self.ExpectedDeaths = trace.ExpectedDeaths[:, indxs, :]
        self.Phi = trace.Phi_1
        
def mask_region(d, region, days=14):
    i = d.Rs.index(region)
    c_s = np.nonzero(np.cumsum(d.NewCases.data[i, :] > 0) == days + 1)[0][0]
    d_s = np.nonzero(np.cumsum(d.NewDeaths.data[i, :] > 0) == days + 1)[0]
    if len(d_s) > 0:
        d_s = d_s[0]
    else:
        d_s = len(d.Ds)

    d.Active.mask[i, c_s:] = True
    d.Confirmed.mask[i, c_s:] = True
    d.Deaths.mask[i, d_s:] = True
    d.NewDeaths.mask[i, d_s:] = True
    d.NewCases.mask[i, c_s:] = True

    return c_s, d_s

In [21]:
data = dp.preprocess_data("./double-entry-data/double_entry_final.csv", last_day="2020-05-30",
                                  schools_unis="whoops")
data.mask_reopenings()

INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Travel Screen/Quarantine
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Travel Bans
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Public Transport Limited
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Internal Movement Limited
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Public Information Campaigns
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Dropping Symptomatic Testing
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Performing Smoothing
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Albania
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Georgia
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Iceland
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing Latvia
INFO:epimodel.pymc3_models.cm_effect.datapreprocessor:Skipping smoothing New Zealand
IN

Masking AL from 2020-04-30 00:00:00+00:00
Masking AD from 2020-05-21 00:00:00+00:00
Masking AT from 2020-05-04 00:00:00+00:00
Masking AT from 2020-05-21 00:00:00+00:00
Masking BE from 2020-05-14 00:00:00+00:00
Masking BA from 2020-05-17 00:00:00+00:00
Masking BG from 2020-05-04 00:00:00+00:00
Masking BG from 2020-05-21 00:00:00+00:00
Masking HR from 2020-04-30 00:00:00+00:00
Masking HR from 2020-05-14 00:00:00+00:00
Masking HR from 2020-05-29 00:00:00+00:00
Masking CZ from 2020-04-27 00:00:00+00:00
Masking CZ from 2020-05-14 00:00:00+00:00
Masking CZ from 2020-05-28 00:00:00+00:00
Masking DK from 2020-04-23 00:00:00+00:00
Masking DK from 2020-05-14 00:00:00+00:00
Masking FI from 2020-05-17 00:00:00+00:00
Masking FR from 2020-05-14 00:00:00+00:00
Masking GE from 2020-05-26 00:00:00+00:00
Masking GE from 2020-05-27 00:00:00+00:00
Masking DE from 2020-04-23 00:00:00+00:00
Masking DE from 2020-05-07 00:00:00+00:00
Masking DE from 2020-05-09 00:00:00+00:00
Masking GR from 2020-05-07 00:00:0

In [71]:
def compute_metric(fold, growth_noise, alpha_noise):
    file = f'../server/diffeff/diffeff_crossval/{alpha_noise}-{growth_noise}-f{fold}.pkl'
    res = pickle.load(open(file, 'rb'))
    
    r_is = [data.Rs.index(r) for r in folds[fold]]
    nd_cases = 0
    nd_deaths = 0
    
    nS, _, _ = res.ExpectedDeaths.shape
    
    cases_ll = 0
    deaths_ll = 0
    
    for i, (rg, r_i) in enumerate(zip(folds[fold], r_is)):
        cases_start, deaths_start = mask_region(data, rg)
        total_cms = data.ActiveCMs[r_i, :, :]
        diff_cms = np.zeros_like(total_cms)
        diff_cms[:, 1:] = total_cms[:, 1:] - total_cms[:, :-1]
        ds = np.nonzero(np.any(diff_cms < 0, axis=0))
        if len(ds[0]) > 0:
            cases_end = ds[0][0]+3
            deaths_end = ds[0][0]+12
        else:
            cases_end = len(data.Ds)-1
            deaths_end = cases_end
        
        nd_cases_rg = cases_end - cases_start + 1
        nd_deaths_rg = deaths_end - deaths_start + 1
        
        nd_cases += nd_cases_rg
        
        # now actually compute likelihoods
        if nd_deaths_rg > 5:
            nd_deaths += nd_deaths_rg
            expected_deaths = res.ExpectedDeaths[:, i, deaths_start:(deaths_end+1)]
            death_dist = pm.NegativeBinomial.dist(mu=expected_deaths, alpha=np.repeat(res.Phi.reshape((nS, 1)), nd_deaths_rg, axis=1))     
            deaths_ll += np.sum(np.log(np.mean(np.exp(death_dist.logp(data.NewDeaths.data[r_i, deaths_start:(deaths_end+1)].reshape((1, nd_deaths_rg))).eval()))))
        
        expected_cases = res.ExpectedDeaths[:, i, cases_start:(cases_end+1)]
        cases_dist = pm.NegativeBinomial.dist(mu=expected_cases, alpha=np.repeat(res.Phi.reshape((nS, 1)), nd_cases_rg, axis=1))     
        cases_ll += np.sum(np.log(np.mean(np.exp(cases_dist.logp(data.NewCases.data[r_i, cases_start:(cases_end+1)].reshape((1, nd_cases_rg))).eval()))))
        
    return ((cases_ll, nd_cases, cases_ll/nd_cases), (deaths_ll, nd_deaths, deaths_ll/nd_deaths))

def aggregate_metric(growth_noise, alpha_noise):
    nd_cases = 0
    nd_deaths = 0
    cases_ll = 0
    deaths_ll = 0
    
    for fold in range(len(folds)):
        cases_res, deaths_res = compute_metric(fold, growth_noise, alpha_noise)
        
        nd_cases += cases_res[1]
        cases_ll += cases_res[0]
        nd_deaths += deaths_res[1]
        deaths_ll += deaths_res[0]
    
    return cases_ll/nd_cases, deaths_ll/nd_deaths

In [76]:
alpha_noise = [0.05, 0.075, 0.1, 0.125, 0.15]
growth_noise = [0.05, 0.1, 0.15, 0.175, 0.2, 0.225, 0.25]

cases_grid = np.zeros((len(alpha_noise), len(growth_noise)))
deaths_grid = np.zeros((len(alpha_noise), len(growth_noise)))

for a_i, alpha in enumerate(alpha_noise):
    for g_i, gnoise in enumerate(growth_noise):
        res = aggregate_metric(gnoise, alpha)
        cases_grid[a_i, g_i] = res[0]
        deaths_grid[a_i, g_i] = res[1]
        print(f'({alpha}, {gnoise}) cases {res[0]:.3f} deaths {res[1]:.3f}')

(0.05, 0.05) cases -0.358 deaths -0.095
(0.05, 0.1) cases -0.337 deaths -0.089
(0.05, 0.15) cases -0.300 deaths -0.085
(0.05, 0.175) cases -0.284 deaths -0.084
(0.05, 0.2) cases -0.266 deaths -0.084
(0.05, 0.225) cases -0.263 deaths -0.083
(0.05, 0.25) cases -0.260 deaths -0.083
(0.075, 0.05) cases -0.353 deaths -0.095
(0.075, 0.1) cases -0.322 deaths -0.090
(0.075, 0.15) cases -0.293 deaths -0.086
(0.075, 0.175) cases -0.277 deaths -0.085
(0.075, 0.2) cases -0.247 deaths -0.084
(0.075, 0.225) cases -0.227 deaths -0.083
(0.075, 0.25) cases -0.230 deaths -0.083
(0.1, 0.05) cases -0.342 deaths -0.094
(0.1, 0.1) cases -0.300 deaths -0.090
(0.1, 0.15) cases -0.278 deaths -0.086
(0.1, 0.175) cases -0.263 deaths -0.085
(0.1, 0.2) cases -0.245 deaths -0.085
(0.1, 0.225) cases -0.220 deaths -0.084
(0.1, 0.25) cases -0.221 deaths -0.084
(0.125, 0.05) cases -0.328 deaths -0.094
(0.125, 0.1) cases -0.287 deaths -0.090
(0.125, 0.15) cases -0.263 deaths -0.087
(0.125, 0.175) cases -0.252 deaths -0.