# EXP05: DifEffects

What this notebook does: fits the different-effects model (`CMCombined_Final_DifEffects`) and plots the resulting country-specific NPI effectiveness estimates.

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each execution.
%load_ext autoreload
%autoreload 2

In [None]:
### Initial imports
import copy
import logging
import pickle
import warnings

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
import theano.tensor as T
from matplotlib.ticker import PercentFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable

sns.set_style("ticks")

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

warnings.simplefilter(action="ignore", category=FutureWarning)

from epimodel.pymc3_models import cm_effect
from epimodel.pymc3_models.cm_effect.datapreprocessor import DataPreprocessor

%matplotlib inline

In [None]:
# Input CSV, given relative to the notebook's working directory.
data_path = "../final_data/data_final.csv"

# Run the project's preprocessor (constructed with drop_HS=True) over the CSV.
dp = DataPreprocessor(drop_HS=True)
data = dp.preprocess_data(data_path)

In [None]:
# Glyphs that are shared by several entries below.
people_icon = "\uf0c0"
shop_icon = "\uf07a"

# (glyph, colour) pairs handed to the summary plot and, later, to the model.
# NOTE(review): the glyphs are private-use code points, presumably from an icon
# font, and presumably one entry per NPI in data.CMs order — confirm.
cm_plot_style = [
    # ("\uf7f2", "tab:red"),  # hospital symbol
    ("\uf963", "black"),  # mask
    ("\uf492", "mediumblue"),  # vial
    (people_icon, "lightgrey"),  # ppl
    (people_icon, "grey"),  # ppl
    (people_icon, "black"),  # ppl
    (shop_icon, "tab:orange"),  # shop 1
    (shop_icon, "tab:red"),  # shop2
    ("\uf19d", "black"),  # school
    ("\uf965", "black"),  # home
]
data.summary_plot(cm_plot_style)

Compared to the usual plot, there are fewer days on which the school NPI is active.

In [None]:
# Instantiate the different-effects model on the preprocessed data; the plot-style
# list is passed straight to the model constructor.
with cm_effect.models.CMCombined_Final_DifEffects(data, cm_plot_style) as model:
    # Override the two noise attributes before build_model() is called.
    model.DailyGrowthNoise = 0.2
    model.RegionVariationNoise = 0.1
    model.build_model()

In [None]:
# Build the graphviz representation once, write it to "model-diff-effects",
# and leave it as the last expression so the notebook displays it.
model_graph = pm.model_to_graphviz(model)
model_graph.render("model-diff-effects")
model_graph

In [None]:
# Draw 2000 samples on each of 4 chains.
# NOTE(review): no random_seed is passed, so the draws differ between runs.
with model.model:
    model.trace = pm.sample(2000, chains=4, cores=4, target_accept=0.925)

# Persist the trace; the context manager guarantees the file handle is flushed
# and closed even if pickling fails.
with open("exp05_diff_effects.pkl", "wb") as trace_file:
    pickle.dump(model.trace, trace_file)

In [None]:
# Call the model's own plot_effect() helper on the sampled trace.
model.plot_effect()

In [None]:
def produce_ranges(trace):
    """Summarise samples along their first axis.

    Parameters
    ----------
    trace : array_like
        Samples; every statistic is taken over axis 0.

    Returns
    -------
    tuple
        ``(mean, median, p2.5, p97.5, p25, p75)`` — the mean, the median, the
        bounds of the central 95% interval and the bounds of the interquartile
        range, in that order.
    """
    centre = (np.mean(trace, axis=0), np.median(trace, axis=0))
    # 95% interval bounds first, then the quartiles, matching the return order.
    bounds = tuple(np.percentile(trace, q, axis=0) for q in (2.5, 97.5, 25, 75))
    return centre + bounds

In [None]:
# Unpack the three dimensions of the "AllCMAlpha" trace array.
# NOTE(review): presumably (samples, regions, NPIs) — later cells index axis 1
# with data.Rs.index(...) and axis 2 with the NPI index; confirm against the model.
nS, nRs, nCMs = model.trace["AllCMAlpha"].shape

In [None]:
# This cell used to evaluate len(x), but no `x` is defined anywhere in the
# notebook, so it raised NameError on a fresh kernel. It now shows the number
# of regions. NOTE(review): confirm this is the quantity that was meant.
len(data.Rs)

In [None]:
# This cell used to evaluate len(g), but no `g` is defined anywhere in the
# notebook, so it raised NameError on a fresh kernel. It now shows the number
# of NPIs. NOTE(review): confirm this is the quantity that was meant.
len(data.CMs)

In [None]:
# Work on a private copy of the region list: the cells below sort `rs` in place,
# and data.Rs must keep its original order because it is used for index lookups.
# (`copy` is already imported in the notebook's import cell.)
rs = copy.deepcopy(data.Rs)

In [None]:
# Median of exp(-alpha) for the first NPI (index 0), one value per region.
# Computed once up front; previously the sort key re-exponentiated the whole
# trace for every region it was called on.
first_cm_medians = np.median(np.exp(-model.trace["AllCMAlpha"][:, :, 0]), axis=0)

# Sort the regions in place by that median, then record each sorted region's
# position in the original data.Rs ordering.
rs.sort(key=lambda region: first_cm_medians[data.Rs.index(region)])
r_index = [data.Rs.index(r) for r in rs]

In [None]:
# Percentage reduction, 100 * (1 - CMReduction), for every sample; the cell
# displays its median over the first (sample) axis.
cm_reduction_pct = 100 * (1 - model.trace["CMReduction"][:, :])
np.median(cm_reduction_pct, axis=0)

In [None]:
sns.set_style("ticks")

plt.figure(figsize=(12, 8), dpi=450)

# Percentage effectiveness, 100 * (1 - exp(-alpha)), for every entry of the
# "AllCMAlpha" trace. It does not depend on the NPI being plotted, so it is
# computed once rather than once per subplot.
res = 100 * (1 - np.exp(-model.trace["AllCMAlpha"]))

for cm in range(len(data.CMs)):
    plt.subplot(3, 3, cm + 1)  # 3x3 grid, one panel per NPI

    # Order the regions by their median effectiveness for this NPI.
    rs.sort(key=lambda x: np.median(res[:, data.Rs.index(x), cm], axis=0))
    r_index = [data.Rs.index(r) for r in rs]

    plt.title(f"{data.CMs[cm]}", fontsize=8)
    plt.xlim([-1, len(r_index)])
    # Red dashed reference line at 0%.
    plt.plot([-5, len(r_index)], [0, 0], "--", color="tab:red", linewidth=0.5)

    # Blue dashed reference line at the median of the "CMReduction" trace for this NPI.
    median = 100 * (1 - np.median(model.trace["CMReduction"][:, cm]))
    plt.plot([-5, len(r_index)], [median, median], "--", color="tab:blue", linewidth=0.5)

    plt.xticks(np.arange(len(r_index)), rs, rotation=90)
    for i, (r, r_i) in enumerate(zip(rs, r_index)):
        # Fade regions where this NPI was active for fewer than 7 days.
        days_active = np.sum(data.ActiveCMs[r_i, cm, :])
        alpha_mult = 0.25 if days_active < 7 else 1

        # One summary per region (the previous duplicate call was discarded unused).
        mn, med, li, ui, lq, uq = produce_ranges(res[:, r_i, cm])
        plt.scatter(i, med, marker="_", s=8, color="k", alpha=1 * alpha_mult)
        plt.plot([i, i], [li, ui], color="k", alpha=0.25 * alpha_mult, linewidth=1)
        plt.plot([i, i], [lq, uq], color="k", alpha=0.75 * alpha_mult, linewidth=1)

    plt.ylabel("Country Specific\nNPI Effectiveness", fontsize=8)
    ax = plt.gca()
    ax.tick_params(axis="both", which="major", labelsize=6)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter())

plt.tight_layout()

In [None]:
plt.figure(figsize=(8, 3), dpi=450)
cm = 0  # index (into data.CMs) of the NPI shown in this figure

# --- Left panel: per-region effectiveness for the chosen NPI -----------------
plt.subplot(1, 2, 1)
res = 100 * (1 - np.exp(-model.trace["AllCMAlpha"]))
# Order the regions by their median effectiveness for this NPI.
rs.sort(key=lambda x: np.median(res[:, data.Rs.index(x), cm], axis=0))
r_index = [data.Rs.index(r) for r in rs]

plt.title(f"{data.CMs[cm]}", fontsize=8)
plt.xlim([-1, len(r_index)])
# Red dashed reference line at 0%.
plt.plot([-5, len(r_index)], [0, 0], "--", color="tab:red", linewidth=0.5)

# Blue dashed reference line at the median of the "CMReduction" trace for this NPI.
median = 100 * (1 - np.median(model.trace["CMReduction"][:, cm]))
plt.plot([-5, len(r_index)], [median, median], "--", color="tab:blue", linewidth=0.5)

plt.xticks(np.arange(len(r_index)), rs, rotation=90)
for i, (r, r_i) in enumerate(zip(rs, r_index)):
    # Fade regions where this NPI was active for fewer than 7 days.
    days_active = np.sum(data.ActiveCMs[r_i, cm, :])
    alpha_mult = 0.25 if days_active < 7 else 1

    mn, med, li, ui, lq, uq = produce_ranges(res[:, r_i, cm])
    plt.scatter(i, med, marker="_", s=8, color="k", alpha=1 * alpha_mult)
    plt.plot([i, i], [li, ui], color="k", alpha=0.25 * alpha_mult, linewidth=1)
    plt.plot([i, i], [lq, uq], color="k", alpha=0.75 * alpha_mult, linewidth=1)

plt.ylabel("Country Specific\nNPI Effectiveness", fontsize=8)
ax = plt.gca()
ax.tick_params(axis="both", which="major", labelsize=6)
ax.yaxis.set_major_formatter(mtick.PercentFormatter())

# --- Right panel: co-activation heatmap --------------------------------------
plt.subplot(1, 2, 2)

# First five and last five regions of the sorted list, reversed for display.
tp = [*rs[:5], *rs[-5:]]
tp.reverse()
mat = np.zeros((len(tp), nCMs))
for i, r in enumerate(tp):
    region = data.Rs.index(r)
    # For each NPI, count the days on which it and the chosen NPI are both active.
    counts = np.sum(data.ActiveCMs[region, :, :] * data.ActiveCMs[region, cm, :] == 1, axis=1)
    # Normalise by the days the chosen NPI was active. If it was never active
    # the row is NaN, which is what the former 0/0 division produced, but
    # without the runtime warning.
    cm_days = counts[cm]
    mat[i, :] = counts / cm_days if cm_days > 0 else np.nan
plt.yticks(np.arange(len(tp)), tp, fontsize=8)
im = plt.imshow(100 * mat, cmap="viridis", vmax=100, vmin=0)
# White line between row 4 and row 5, i.e. between the two groups of five regions.
plt.plot([-1, 10], [4.5, 4.5], color="white", linewidth=2)
plt.xlim([-0.5, 8.5])
plt.xlabel("NPI $i$", fontsize=8)
plt.xticks(np.arange(len(data.CMs)), data.CMs, ha="left", rotation=-20, fontsize=5)
# Raw f-string: "\p" is not a valid escape sequence in a normal string literal.
plt.title(rf"Frequency[$\phi_i = 1| \phi_{cm} = 1$]", fontsize=8)

plt.tight_layout()
ax = plt.gca()
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, format=PercentFormatter())
ax = plt.gca()
ax.tick_params(axis="both", which="major", labelsize=8)

plt.tight_layout()
plt.savefig("FigureSAnew.pdf", bbox_inches='tight')

In [None]:
# Reload the pickled trace; the context manager closes the file handle afterwards.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# pickles produced by a trusted run of this notebook.
with open("exp05_diff_effects.pkl", "rb") as trace_file:
    model.trace = pickle.load(trace_file)

In [None]:

plt.figure(figsize=(8, 12), dpi=450)

# Percentage effectiveness for every entry of the "AllCMAlpha" trace; it does
# not depend on the NPI, so it is computed once for the whole figure.
res = 100 * (1 - np.exp(-model.trace["AllCMAlpha"]))

cms_to_plot = [0, 1, 2, 3, 4]  # NPIs (indices into data.CMs) shown, one per row

for row, cm in enumerate(cms_to_plot):
    # --- Left column: per-region effectiveness for this NPI ------------------
    plt.subplot(5, 2, row * 2 + 1)
    rs.sort(key=lambda x: np.median(res[:, data.Rs.index(x), cm], axis=0))
    r_index = [data.Rs.index(r) for r in rs]

    plt.title(f"{data.CMs[cm]}", fontsize=8)
    plt.xlim([-1, len(r_index)])
    # Red dashed reference line at 0%.
    plt.plot([-5, len(r_index)], [0, 0], "--", color="tab:red", linewidth=0.5)

    # Blue dashed reference line at the median of the "CMReduction" trace for this NPI.
    median = 100 * (1 - np.median(model.trace["CMReduction"][:, cm]))
    plt.plot([-5, len(r_index)], [median, median], "--", color="tab:blue", linewidth=0.5)

    plt.xticks(np.arange(len(r_index)), rs, rotation=90)
    for i, (r, r_i) in enumerate(zip(rs, r_index)):
        # Fade regions where this NPI was active for fewer than 7 days.
        days_active = np.sum(data.ActiveCMs[r_i, cm, :])
        alpha_mult = 0.25 if days_active < 7 else 1

        mn, med, li, ui, lq, uq = produce_ranges(res[:, r_i, cm])
        plt.scatter(i, med, marker="_", s=8, color="k", alpha=1 * alpha_mult)
        plt.plot([i, i], [li, ui], color="k", alpha=0.25 * alpha_mult, linewidth=1)
        plt.plot([i, i], [lq, uq], color="k", alpha=0.75 * alpha_mult, linewidth=1)

    plt.ylabel("Country Specific\nNPI Effectiveness", fontsize=8)
    ax = plt.gca()
    ax.tick_params(axis="both", which="major", labelsize=6)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    # --- Right column: co-activation heatmap ---------------------------------
    plt.subplot(5, 2, row * 2 + 2)

    # First five and last five regions of the sorted list, reversed for display.
    tp = [*rs[:5], *rs[-5:]]
    tp.reverse()
    mat = np.zeros((len(tp), nCMs))
    for i, r in enumerate(tp):
        region = data.Rs.index(r)
        # For each NPI, count the days on which it and this NPI are both active.
        counts = np.sum(data.ActiveCMs[region, :, :] * data.ActiveCMs[region, cm, :] == 1, axis=1)
        # Normalise by the days this NPI was active; NaN row if it never was
        # (same value the former 0/0 division gave, minus the runtime warning).
        cm_days = counts[cm]
        mat[i, :] = counts / cm_days if cm_days > 0 else np.nan
    plt.yticks(np.arange(len(tp)), tp, fontsize=8)
    im = plt.imshow(100 * mat, cmap="viridis", vmax=100, vmin=0)
    # White line between row 4 and row 5, i.e. between the two groups of five regions.
    plt.plot([-1, 10], [4.5, 4.5], color="white", linewidth=2)
    plt.xlim([-0.5, nCMs - 0.5])

    # Label the NPI axis on the bottom row only. The previous test, `i == 4`,
    # read the leftover index of the loop over `tp`, which always ends on an
    # odd number, so the labels were never drawn.
    if row == len(cms_to_plot) - 1:
        plt.xlabel("NPI $i$", fontsize=8)
        plt.xticks(np.arange(len(data.CMs)), data.CMs, ha="left", rotation=-20, fontsize=6)
    else:
        plt.xticks([])
    # Raw f-string: "\p" is not a valid escape sequence in a normal string literal.
    plt.title(rf"Frequency[$\phi_i = 1| \phi_{cm} = 1$]", fontsize=8)

    ax = plt.gca()
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax, format=PercentFormatter())
    ax = plt.gca()
    ax.tick_params(axis="both", which="major", labelsize=8)

plt.tight_layout()
plt.savefig("all_one.pdf", bbox_inches='tight')

In [None]:
# Fresh copy of the region list; it is sorted in place for every NPI below.
rs = copy.deepcopy(data.Rs)

sns.set_style("ticks")

# Number of NPIs, read from the activation array instead of a hard-coded 9 so
# that `mat` below always matches the length of `counts`.
nCMs = data.ActiveCMs.shape[1]

plt.figure(figsize=(8, 12), dpi=450)

# Percentage effectiveness for every entry of the "AllCMAlpha" trace; it does
# not depend on the NPI, so it is computed once for the whole figure.
res = 100 * (1 - np.exp(-model.trace["AllCMAlpha"]))

cms_to_plot = [5, 6, 7, 8]  # NPIs (indices into data.CMs) shown, one per row

for row, cm in enumerate(cms_to_plot):
    # --- Left column: per-region effectiveness for this NPI ------------------
    # The grid keeps 5 rows although only 4 are used.
    plt.subplot(5, 2, row * 2 + 1)
    rs.sort(key=lambda x: np.median(res[:, data.Rs.index(x), cm], axis=0))
    r_index = [data.Rs.index(r) for r in rs]

    plt.title(f"{data.CMs[cm]}", fontsize=8)
    plt.xlim([-1, len(r_index)])
    # Red dashed reference line at 0%.
    plt.plot([-5, len(r_index)], [0, 0], "--", color="tab:red", linewidth=0.5)

    # Blue dashed reference line at the median of the "CMReduction" trace for this NPI.
    median = 100 * (1 - np.median(model.trace["CMReduction"][:, cm]))
    plt.plot([-5, len(r_index)], [median, median], "--", color="tab:blue", linewidth=0.5)

    plt.xticks(np.arange(len(r_index)), rs, rotation=90)
    for i, (r, r_i) in enumerate(zip(rs, r_index)):
        # Fade regions where this NPI was active for fewer than 7 days.
        days_active = np.sum(data.ActiveCMs[r_i, cm, :])
        alpha_mult = 0.25 if days_active < 7 else 1

        mn, med, li, ui, lq, uq = produce_ranges(res[:, r_i, cm])
        plt.scatter(i, med, marker="_", s=8, color="k", alpha=1 * alpha_mult)
        plt.plot([i, i], [li, ui], color="k", alpha=0.25 * alpha_mult, linewidth=1)
        plt.plot([i, i], [lq, uq], color="k", alpha=0.75 * alpha_mult, linewidth=1)

    plt.ylabel("Country Specific\nNPI Effectiveness", fontsize=8)
    ax = plt.gca()
    ax.tick_params(axis="both", which="major", labelsize=6)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    # --- Right column: co-activation heatmap ---------------------------------
    plt.subplot(5, 2, row * 2 + 2)

    # First five and last five regions of the sorted list, reversed for display.
    tp = [*rs[:5], *rs[-5:]]
    tp.reverse()
    mat = np.zeros((len(tp), nCMs))
    for i, r in enumerate(tp):
        region = data.Rs.index(r)
        # For each NPI, count the days on which it and this NPI are both active.
        counts = np.sum(data.ActiveCMs[region, :, :] * data.ActiveCMs[region, cm, :] == 1, axis=1)
        # Normalise by the days this NPI was active; NaN row if it never was
        # (same value the former 0/0 division gave, minus the runtime warning).
        cm_days = counts[cm]
        mat[i, :] = counts / cm_days if cm_days > 0 else np.nan
    plt.yticks(np.arange(len(tp)), tp, fontsize=8)
    im = plt.imshow(100 * mat, cmap="viridis", vmax=100, vmin=0)
    # White line between row 4 and row 5, i.e. between the two groups of five regions.
    plt.plot([-1, 10], [4.5, 4.5], color="white", linewidth=2)
    plt.xlim([-0.5, nCMs - 0.5])

    # Label the NPI axis on the bottom row only. The previous test, `i == 4`,
    # read the leftover index of the loop over `tp`, which always ends on an
    # odd number, so the labels were never drawn.
    if row == len(cms_to_plot) - 1:
        plt.xlabel("NPI $i$", fontsize=8)
        plt.xticks(np.arange(len(data.CMs)), data.CMs, ha="left", rotation=-20, fontsize=6)
    else:
        plt.xticks([])
    # Raw f-string: "\p" is not a valid escape sequence in a normal string literal.
    plt.title(rf"Frequency[$\phi_i = 1| \phi_{cm} = 1$]", fontsize=8)

    ax = plt.gca()
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax, format=PercentFormatter())
    ax = plt.gca()
    ax.tick_params(axis="both", which="major", labelsize=8)

plt.tight_layout()
plt.savefig("all_two.pdf", bbox_inches='tight')