In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
### Initial imports
import numpy as np
import pickle
import pandas as pd
import pymc3 as pm
import theano.tensor as T
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("ticks")

import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)

from epimodel.pymc3_models import cm_effect
from epimodel.pymc3_models.cm_effect.datapreprocessor import DataPreprocessor
from epimodel.pymc3_models.cm_effect.models import add_cms_to_plot

%matplotlib inline

In [None]:
# (glyph, colour) pairs passed to add_cms_to_plot, one entry per countermeasure.
# Presumably the order matches the countermeasure rows of data.ActiveCMs — TODO confirm.
# The glyphs are private-use code points, so they only render with an icon font
# (assumed to be set up by add_cms_to_plot — verify).
# A hospital glyph ("\uf7f2", "tab:red") used to sit first in this list and is disabled.
cm_plot_style = [
    ("\uf963", "black"),       # mask
    ("\uf0c0", "darkgrey"),    # people
    ("\uf0c0", "dimgrey"),     # people
    ("\uf0c0", "black"),       # people
    ("\uf07a", "tab:orange"),  # shop 1
    ("\uf07a", "tab:red"),     # shop 2
    ("\uf549", "black"),       # school
    ("\uf19d", "black"),       # university
    ("\uf965", "black"),       # home
]

def _nth_positive_day_index(series, n, default):
    """Index of the n-th strictly positive entry of `series`, or `default` if there is none."""
    hits = np.nonzero(np.cumsum(series > 0) == n)[0]
    return hits[0] if len(hits) > 0 else default


def mask_region(d, region, days=14):
    """Hold out a region's observations after its first `days` non-zero days.

    The case cutoff is the day index of the (days+1)-th day with NewCases > 0. Every
    later entry of Active, Confirmed and NewCases for the region is masked. The death
    cutoff is computed the same way from NewDeaths, and it masks Deaths and NewDeaths.
    If a series never reaches days+1 non-zero days, its cutoff is len(d.Ds), so
    nothing is masked for it.

    Args:
        d: data object exposing `Rs` (region list), `Ds` (day list) and masked arrays
            `NewCases`, `NewDeaths`, `Active`, `Confirmed`, `Deaths`, indexed as
            [region, day]. Their `.mask` attributes are modified in place.
        region: entry of `d.Rs` to mask.
        days: number of non-zero days to keep visible.

    Raises:
        ValueError: if `region` is not in `d.Rs`.
    """
    i = d.Rs.index(region)
    n_days = len(d.Ds)
    # Both cutoffs share one guarded helper. Previously the case cutoff raised
    # IndexError when the region had too few non-zero case days.
    c_s = _nth_positive_day_index(d.NewCases.data[i, :], days + 1, n_days)
    d_s = _nth_positive_day_index(d.NewDeaths.data[i, :], days + 1, n_days)

    d.Active.mask[i, c_s:] = True
    d.Confirmed.mask[i, c_s:] = True
    d.Deaths.mask[i, d_s:] = True
    d.NewDeaths.mask[i, d_s:] = True
    d.NewCases.mask[i, c_s:] = True

In [None]:
# Preprocessor used below to load the countermeasure/epidemiological CSV.
# drop_HS=True presumably drops the "HS" countermeasure feature(s) — confirm in DataPreprocessor.
dp = DataPreprocessor(drop_HS=True)

In [None]:
_MONTH_ABBREVIATIONS = {
    1: "JAN", 2: "FEB", 3: "MAR", 4: "APR", 5: "MAY", 6: "JUN",
    7: "JUL", 8: "AUG", 9: "SEP", 10: "OCT", 11: "NOV", 12: "DEC",
}


def month_to_str(x):
    """Return the upper-case three-letter English abbreviation of month number `x`.

    Covers all months 1-12. The previous version handled only JAN-MAY and silently
    returned None for later months, which produced tick labels like "1-None".
    A hard-coded table is used instead of `calendar.month_abbr` so the result does
    not depend on the locale.

    Returns:
        e.g. "JAN" for 1, or None if `x` is not a month number.
    """
    return _MONTH_ABBREVIATIONS.get(x)

In [None]:
class ResultsObject:
    """Slice of a sampled trace for a single region `indx`.

    CMReduction is kept unsliced. RegionR is indexed as [:, indx], and the four
    Infected*/Expected* variables as [:, indx, :]. Axis 0 is presumably the sample
    axis, axis 1 the region axis and the trailing axis days — confirm against the model.
    The class name and attribute names should stay stable: instances are presumably
    what the holdout .pkl files contain, and unpickling looks the class up by name.
    """

    def __init__(self, indx, trace):
        # Kept whole: not indexed by region.
        self.CMReduction = trace.CMReduction
        self.RegionR = trace.RegionR[:, indx]
        # Remaining variables all carry a trailing axis after the region axis.
        for name in ("InfectedCases", "InfectedDeaths", "ExpectedCases", "ExpectedDeaths"):
            setattr(self, name, getattr(trace, name)[:, indx, :])
        
def produce_CIs(data):
    """Summarise samples along axis 0 as a median with a central 95% interval.

    Returns:
        (median, lower, upper, err):
        - median: the median over axis 0, not the mean.
        - lower, upper: the 2.5th and 97.5th percentiles.
        - err: stacked [median - lower, upper - median], i.e. a (2, ...) array of
          lower/upper distances from the median.
    """
    lower, upper = (np.percentile(data, q, axis=0) for q in (2.5, 97.5))
    median = np.median(data, axis=0)
    distances = np.array([median - lower, upper - median])
    return median, lower, upper, distances

In [None]:
# Per-region holdout plots, laid out as a 4x3 grid of regions on each PDF page.
REGIONS_PER_PAGE = 12
GRID_ROWS, GRID_COLS = 4, 3
# Horizontal bbox_to_anchor offsets, in axes-fraction units of the subplot that carries
# the shared legend, indexed by that subplot's grid column. They centre the legend under
# the middle column. -0.85 (right column) and 1.9 (left column) are hand-tuned values.
LEGEND_X_BY_COLUMN = (1.9, 0.5, -0.85)
NEGBIN_ALPHA = 60  # alpha passed to pm.NegativeBinomial for the predictive draws
HOLDOUT_RESULTS_DIR = "../../server/ho_results_final4"


def _negbin_draws(mu, alpha=NEGBIN_ALPHA):
    """Draw one NegativeBinomial(mu, alpha) sample per entry of `mu`.

    Returns None if PyMC3 raises ValueError for these parameters, so the caller can
    substitute a placeholder.
    """
    try:
        return pm.NegativeBinomial.dist(mu=mu, alpha=alpha).random()
    except ValueError:
        return None


def _plot_observable(ax, days_x, infections, predicted, observed,
                     infection_color, predicted_color, labels):
    """Draw one observable (deaths or cases) for a region onto `ax`.

    infections, predicted: (median, lower, upper) triples, e.g. from produce_CIs.
    observed: masked array of recorded counts. Unmasked values are drawn as filled
        markers; the full underlying `.data` (held-out days included) as hollow ones.
    labels: legend labels for (infections, predicted, recorded, held-out).
    """
    inf_mid, inf_lo, inf_hi = infections
    pred_mid, pred_lo, pred_hi = predicted
    inf_label, pred_label, rec_label, heldout_label = labels

    ax.plot(days_x, inf_mid, label=inf_label, zorder=1, color=infection_color, alpha=0.25)
    ax.fill_between(days_x, inf_lo, inf_hi, alpha=0.15, color=infection_color, linewidth=0)
    ax.plot(days_x, pred_mid, label=pred_label, zorder=2, color=predicted_color)
    ax.fill_between(days_x, pred_lo, pred_hi, alpha=0.25, color=predicted_color, linewidth=0)
    ax.scatter(days_x, observed, label=rec_label, marker="o", s=10,
               color=predicted_color, alpha=0.9, zorder=3)
    ax.scatter(days_x, observed.data, label=heldout_label, marker="o", s=12,
               edgecolor=predicted_color, facecolor="white", linewidth=1, alpha=0.9, zorder=2)


def _finish_page(fig, ax, column, page_number):
    """Lay out `fig`, add the shared legend below `ax`, save it as FigureHoldouts{page_number}.pdf and close it."""
    fig.tight_layout()
    ax.legend(shadow=True, fancybox=True, loc="upper center",
              bbox_to_anchor=(LEGEND_X_BY_COLUMN[column], -0.25), fontsize=8, ncol=4)
    fig.savefig(f"FigureHoldouts{page_number}.pdf", bbox_inches='tight')
    plt.close(fig)  # free the 300-dpi figure once it is on disk


data = dp.preprocess_data("double_entry_final.csv", last_day="2020-05-30", schools_unis="lol")
data.mask_reopenings()

start_d_i = 30  # first day index shown on every subplot
days = data.Ds
days_x = np.arange(len(days))
n_regions = len(data.Rs)

for r_i, r in enumerate(data.Rs):
    p_i = r_i % REGIONS_PER_PAGE   # slot on the current page
    f_i = r_i // REGIONS_PER_PAGE  # zero-based page index

    if p_i == 0:
        # The first page is slightly taller than the later ones (hand-tuned sizes).
        fig = plt.figure(figsize=(10, 14) if f_i == 0 else (10, 13), dpi=300)
    ax = plt.subplot(GRID_ROWS, GRID_COLS, p_i + 1)

    # The x-axis stops 3 days before the first already-masked case entry, if there is
    # one (e.g. masked by mask_reopenings). This must be read before mask_region adds
    # the holdout mask.
    masked_days = np.nonzero(data.NewCases.mask[r_i, :])[0]
    end_d_i = masked_days[0] - 3 if len(masked_days) > 0 else len(data.Ds)
    mask_region(data, r)

    # NOTE: pickle.load can execute arbitrary code; only load result files you produced.
    with open(f"{HOLDOUT_RESULTS_DIR}/{r}.pkl", "rb") as fh:
        res = pickle.load(fh)

    means_d, lu_id, up_id, err_d = produce_CIs(res.InfectedDeaths)
    means_c, lu_ic, up_ic, err_c = produce_CIs(res.InfectedCases)

    ec = res.ExpectedDeaths
    ec_output = _negbin_draws(ec)
    if ec_output is None:
        # Sampling failed: blank out the deaths-side bands.
        ec_output = 0 * ec
        means_d = means_d * 0
        lu_id = lu_id * 0
        up_id = up_id * 0
    means_expected_deaths, lu_ed, up_ed, err_expected_deaths = produce_CIs(ec_output)

    eco = res.ExpectedCases
    eco_output = _negbin_draws(eco)
    if eco_output is None:
        # Sampling failed: use a near-zero placeholder shaped like the expected *cases*.
        # The old code mistakenly scaled the expected deaths here.
        eco_output = 10**-10 * eco
    means_expected_cases, lu_ec, up_ec, err_expected_cases = produce_CIs(eco_output)

    deaths = data.NewDeaths[r_i, :]
    cases = data.NewCases[r_i, :]

    _plot_observable(
        ax, days_x,
        (means_d, lu_id, up_id),
        (means_expected_deaths, lu_ed, up_ed),
        deaths,
        infection_color="tab:orange",
        predicted_color="tab:red",
        labels=("Daily Infections - Later Fatal", "Predicted Daily Deaths",
                "Recorded Daily Deaths (Smoothed)", "Heldout Daily Deaths (Cases)"),
    )
    _plot_observable(
        ax, days_x,
        (means_c, lu_ic, up_ic),
        (means_expected_cases, lu_ec, up_ec),
        cases,
        infection_color="darkgreen",
        predicted_color="tab:blue",
        labels=("Daily Infections - Later Reported", "Predicted Daily Confirmed Cases",
                "Recorded Daily Confirmed Cases (Smoothed)", "Heldout Daily Confirmed Cases (Cases)"),
    )

    ax.set_yscale("log")
    ax.set_ylim([10 ** 0, 10 ** 6])
    locs = np.arange(start_d_i, end_d_i, 14)  # one tick every 14 days
    xlabels = [f"{data.Ds[ts].day}-{month_to_str(data.Ds[ts].month)}" for ts in locs]
    ax.set_xticks(locs)
    ax.set_xticklabels(xlabels, rotation=-30, ha="left")
    ax.set_xlim((start_d_i, end_d_i))

    add_cms_to_plot(ax, data.ActiveCMs, r_i, start_d_i, end_d_i, data.Ds, cm_plot_style)
    ax.set_title(data.RNames[r][0], fontsize=12)

    # Save only after this region is drawn: once the page is full, or after the final
    # region. The old code saved the last page before plotting its last region, and
    # never saved a last page that began with that region.
    if p_i == REGIONS_PER_PAGE - 1 or r_i + 1 == n_regions:
        _finish_page(fig, ax, p_i % GRID_COLS, f_i + 1)