In [None]:
import os
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import numpy as np
np.random.seed(123456)

import sys; sys.path.append("..")
from epimodel import EpidemiologicalParameters
from epimodel.pymc3_models.mask_models import RandomWalkMobilityModel, MandateMobilityModel
from epimodel.preprocessing.preprocess_mask_data import Preprocess_masks
import epimodel.viz.region_plot as rp
import epimodel.viz.prior_posterior as pp
import epimodel.viz.yougov as yg
import epimodel.viz.pred_cases as pc
import epimodel.viz.result_plot as rep
import epimodel.viz.mandate_wearing as mw
import epimodel.viz.empirical_wearing as ew

import calendar 
import theano.tensor as T
import theano.tensor.signal.conv as C
import theano

import pymc3 as pm
import pandas as pd
import copy
import re
import pickle
import datetime
from datetime import timedelta
import argparse
from pathlib import Path
import matplotlib.dates as mdates
import matplotlib.ticker as mtick

import pickle 
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import arviz as az
import json

sns.set(style="ticks", font='DejaVu Serif')
PNAS_WIDTH_INCHES = 3.4252

In [None]:
# Run configuration. The trailing comments name the `args.*` value each
# constant stands in for.
MODEL = "cases"    # args.model
MASKS = "wearing"  # args.masks
W_PAR = "exp"      # args.w_par

n_mandates = 2

MOBI = 'include'   # args.mob
US = True
SMOOTH = False
GATHERINGS = 3     # args.gatherings if args.gatherings else 3
# MASKING = True # Always true

# The "both" model gets 2500 tuning steps; every other model gets 2000.
TUNING = 2500 if MODEL == "both" else 2000

# Daily analysis window, endpoints included.
Ds = pd.date_range(start="2020-05-01", end="2020-09-21", freq="D")

In [None]:
# prep data object

# The modelling-set file name encodes the mobility setting and the US flag.
path = f"../data/modelling_set/master_data_mob_{MOBI}_us_{US}_m_w.csv"
masks_object = Preprocess_masks(path=path)

# Use the configured `n_mandates` rather than repeating the literal 2, so the
# configuration cell stays the single place where this value is chosen.
masks_object.featurize(
    gatherings=GATHERINGS,
    masks=MASKS,
    smooth=SMOOTH,
    mobility=MOBI,
    n_mandates=n_mandates,
)
masks_object.make_preprocessed_object()
data = masks_object.data

In [None]:
# just mobility
# Raw modelling-set table, indexed by (country, date). `date` stays a string
# here because the CSV is read without date parsing.
path = "../data/modelling_set/master_data_mob_include_us_True_m_w.csv"
mobility_data = pd.read_csv(path).set_index(["country", "date"])



In [None]:
def load_oxcgrt(use_us=True):
    """Load the OxCGRT policy table restricted to the analysis window.

    Reads ``../data/raw/OxCGRT_latest.csv`` and relies on the module-level
    date range ``Ds``.

    Args:
        use_us: when True, append the US ``STATE_TOTAL`` rows with the state
            name copied into ``CountryName``. The country Georgia is dropped
            and the US state is renamed "Georgia-US" to avoid a name clash.

    Returns:
        DataFrame of national (and optionally US state) rows whose ``date``
        column (renamed from ``Date``) falls in ``Ds``.
    """
    OXCGRT_PATH = "../data/raw/OxCGRT_latest.csv"
    oxcgrt = pd.read_csv(OXCGRT_PATH, parse_dates=["Date"], low_memory=False)
    # Drop regional data
    nat = oxcgrt[oxcgrt.Jurisdiction == "NAT_TOTAL"]

    # Add US states
    if use_us:
        # .copy() so the column assignment below writes to an independent frame
        # instead of a slice of `oxcgrt` (avoids SettingWithCopyWarning).
        states = oxcgrt[
            (oxcgrt.CountryName == "United States")
            & (oxcgrt.Jurisdiction == "STATE_TOTAL")
        ].copy()
        # Drop GEO to prevent name collision
        nat = nat[nat.CountryName != "Georgia"]
        states["CountryName"] = states["RegionName"]
        states = states.replace("Georgia", "Georgia-US")
        nat = pd.concat([nat, states])

    nat = nat.rename(columns={"Date": "date"})

    return nat[nat.date.isin(Ds)]


# Divisor applied to each OxCGRT indicator (keyed by its two-character code)
# when the sub-index functions below rescale it to 0-100.
maxes = dict(C1=3, C2=3, C3=2, C4=4, C5=2, C6=3, C7=2, C8=4, H1=2)

def subindex(df, c) :
    """Rescale indicator column ``c`` to 0-100, docking half a level when its flag is 0.

    The flag column is ``<code>_Flag`` where ``<code>`` is the first two
    characters of ``c``; the divisor comes from the module-level ``maxes``.
    """
    code = c[:2]
    penalty = 0.5 * (1 - df[code + "_Flag"])
    adjusted = df[c] - penalty
    return (adjusted / maxes[code]) * 100


# now with no weight on regional
def nat_subindex(df, c) :
    """Rescale indicator column ``c`` to 0-100, zeroing rows whose flag is 0.

    The indicator is multiplied by its ``<code>_Flag`` column before being
    divided by ``maxes[<code>]``.
    """
    code = c[:2]
    flagged = df[c] * df[code + "_Flag"]
    return (flagged / maxes[code]) * 100

# https://github.com/OxCGRT/covid-policy-tracker/blob/master/documentation/index_methodology.md
def stringency(data, national=True):
    """Row-wise mean of the rescaled C1-C8 and H1 indicators, rounded to 2 decimals.

    Args:
        data: frame holding the indicator columns and their ``*_Flag`` columns.
        national: use ``nat_subindex`` (flag multiplies the value) when True,
            otherwise ``subindex`` (flag of 0 subtracts half a level).

    Returns:
        Series with one index value per row of ``data``.

    NOTE(review): ``mean(axis=1)`` skips NaN by default, so a sub-index that is
    NaN (e.g. because its flag is missing) is left out of the average rather
    than counted as 0. Confirm this matches the intended methodology.
    """
    sub = nat_subindex if national else subindex
    codes = ["C1", "C2", "C3", "C4", "C5", "C6", "C7", "C8", "H1"]

    # Indicator columns: a stringency code prefix and not a flag column.
    indicator_cols = [
        c for c in data.columns if "Flag" not in c and c[:2] in codes
    ]

    df = data.copy()
    for c in indicator_cols:
        if c[:2] == "C8":
            # C8 is rescaled directly, without consulting a flag column.
            df[c] = df[c] / maxes[c[:2]] * 100
        else:
            df[c] = sub(df, c)

    return df[indicator_cols].mean(axis=1).round(2)


oxcgrt = load_oxcgrt()
# Take an explicit copy of the filtered rows so the new column below is written
# to an independent frame rather than a slice (avoids SettingWithCopyWarning).
oxcgrt = oxcgrt[oxcgrt.date.isin(Ds)].copy()
oxcgrt["StringencyIndexNational"] = stringency(oxcgrt)

In [None]:
# https://rt.live/

def load_rts() :
    """Load third-party R_t estimates for US states (rt.live) and countries (Epifor).

    Reads ``../data/raw/us_rt_live.csv``, ``../data/raw/r_estimates_epifor.csv``
    and ``../data/raw/3166.json``; relies on the module-level date range ``Ds``.

    Returns:
        DataFrame stacking the Epifor rows on top of the rt.live rows. Both
        parts carry ``date``, ``region``, ``mean``, ``lower_80`` and
        ``upper_80``; ``mean_minus10`` holds the ``mean`` found 10 rows later
        within the same region.
    """
    third_party_rts_us = pd.read_csv("../data/raw/us_rt_live.csv", parse_dates=["date"])
    us_state_abbrev = {
        'Alabama': 'AL', 'Alaska': 'AK', 'American Samoa': 'AS', 'Arizona': 'AZ',
        'Arkansas': 'AR', 'California': 'CA', 'Colorado': 'CO', 'Connecticut': 'CT',
        'Delaware': 'DE', 'District of Columbia': 'DC', 'Florida': 'FL', 'Georgia': 'GA',
        'Guam': 'GU', 'Hawaii': 'HI', 'Idaho': 'ID', 'Illinois': 'IL', 'Indiana': 'IN',
        'Iowa': 'IA', 'Kansas': 'KS', 'Kentucky': 'KY', 'Louisiana': 'LA', 'Maine': 'ME',
        'Maryland': 'MD', 'Massachusetts': 'MA', 'Michigan': 'MI', 'Minnesota': 'MN',
        'Mississippi': 'MS', 'Missouri': 'MO', 'Montana': 'MT', 'Nebraska': 'NE',
        'Nevada': 'NV', 'New Hampshire': 'NH', 'New Jersey': 'NJ', 'New Mexico': 'NM',
        'New York': 'NY', 'North Carolina': 'NC', 'North Dakota': 'ND',
        'Northern Mariana Islands': 'MP', 'Ohio': 'OH', 'Oklahoma': 'OK', 'Oregon': 'OR',
        'Pennsylvania': 'PA', 'Puerto Rico': 'PR', 'Rhode Island': 'RI',
        'South Carolina': 'SC', 'South Dakota': 'SD', 'Tennessee': 'TN', 'Texas': 'TX',
        'Utah': 'UT', 'Vermont': 'VT', 'Virgin Islands': 'VI', 'Virginia': 'VA',
        'Washington': 'WA', 'West Virginia': 'WV', 'Wisconsin': 'WI', 'Wyoming': 'WY',
    }
    # Invert to abbreviation -> full name, then apply the Georgia disambiguation
    # used elsewhere in this notebook.
    abbrev_to_name = {value: key for key, value in us_state_abbrev.items()}
    third_party_rts_us["region"] = third_party_rts_us["region"].replace(abbrev_to_name)
    third_party_rts_us["region"] = third_party_rts_us["region"].replace({"Georgia": "Georgia-US"})
    cols = ['date', 'region', 'mean', 'lower_80', 'upper_80']
    third_party_rts_us = third_party_rts_us[cols]

    # http://epidemicforecasting.org/country-rt-estimates
    epifor_rts = pd.read_csv("../data/raw/r_estimates_epifor.csv", parse_dates=["Date"])
    epifor_rts = epifor_rts[epifor_rts.EnoughData == 1]
    # .copy() so the column writes below do not target a slice of the raw frame.
    epifor_rts = epifor_rts[epifor_rts.Date.isin(Ds)].copy()
    epifor_rts["Date"] = epifor_rts["Date"].dt.date #pd.to_datetime(epifor_rts.Date, utc=True)

    def recode(epifor_rts) :
        """Add a ``region`` column: alpha-2 ``Code`` mapped to its ISO 3166-1 name."""
        with open("../data/raw/3166.json") as f:
            codes = json.load(f)
            codes = codes['3166-1']

        map_ = {c.get("alpha_2"): c.get("name") for c in codes}
        epifor_rts["region"] = epifor_rts.Code.replace(map_)
        # Rename the ISO names that differ from the names used in this project.
        epifor_rts["region"] = epifor_rts["region"].replace({"Tanzania, United Republic of": "Tanzania"})
        epifor_rts["region"] = epifor_rts["region"].replace({"Viet Nam": "Vietnam"})
        epifor_rts["region"] = epifor_rts["region"].replace({"Korea, Republic of": "South Korea"})

        return epifor_rts

    epifor_rts = recode(epifor_rts)
    # NOTE(review): positional rename; assumes the CSV has exactly five columns
    # in the order code, date, mean, std, enough-data flag. Confirm against the file.
    epifor_rts.columns = ["code", "date", "mean", "std", "enough", "region"]

    z80 = 0.842
    epifor_rts["lower_80"] = epifor_rts["mean"] - z80 * epifor_rts["std"]
    epifor_rts["upper_80"] = epifor_rts["mean"] + z80 * epifor_rts["std"]
    #epifor_rts = epifor_rts[cols]
    third_party_rts = pd.concat([epifor_rts, third_party_rts_us])

    # Correct for absence of case delay in Epifor
    # Shift within each region: a plain shift over the stacked frame would fill
    # the last 10 rows of every region with values from the following region.
    # NOTE(review): like the original, this assumes rows are in date order
    # within each region.
    third_party_rts["mean_minus10"] = (
        third_party_rts.groupby("region", sort=False)["mean"].shift(-10)
    )
    third_party_rts["date"] = pd.to_datetime(third_party_rts["date"])

    return third_party_rts


third_party_rts = load_rts()

# What's the empirical Rt?

In [None]:
# Keep only third-party Rt rows inside the analysis window and for modelled
# regions, then display the mean and standard deviation of the shifted estimate.
relevant = third_party_rts[third_party_rts.date.isin(Ds)]
relevant = relevant[relevant.region.isin(data.Rs)]
relevant["mean_minus10"].mean(), relevant["mean_minus10"].std()

In [None]:
# Dimensions of the preprocessed data object: regions, interventions, days.
print("How many regions?:", len(data.Rs))
print("How many NPIs?:", len(data.CMs))
print("How many days?:", len(data.Ds))

# Load pickle

In [None]:
# File names of saved traces, resolved relative to `pickles/` by load_pickle().
mandate_pkl_5pct = "mandate_2and3_cases_countries_92_05-28-16:44.pkl"
mandate_pkl_0 = "mandate_2and3_cases_countries_92_06-04-16:37.pkl"
q2_pkl_1000_250 = "zeroed_wearing_hyper_Rs_wearing_log_quadratic_2_cases_countries_92_05-25-18:37.pkl"
exp_pkl = "wearing_exp_cases_countries_92_05-31-02:56.pkl" # 1000 + 700 
ll_pkl = "wearing_log_linear_cases_countries_92_06-05-16:46.pkl"


def load_pickle(p) :
    """Unpickle ``pickles/<p>`` and read its optional column-name sidecar.

    The sidecar has the same path with the last four characters replaced by
    ``_cols`` and holds names separated by ", ". When it is absent,
    "cols missing" is printed and an empty list is returned in its place.

    Only use with trusted files: ``pickle.load`` can execute arbitrary code.

    Returns:
        ``(trace, npi_cols)``.
    """
    path = 'pickles/' + p
    with open(path, 'rb') as buff:
        trace = pickle.load(buff)

    colfile = path[:-4] + "_cols"
    if not Path(colfile).is_file():
        print("cols missing")
        return trace, []

    with open(colfile, "r") as f:
        names = f.read().split(", ")
    return trace, names


# Load the mandate-model trace and the exponential wearing-model trace; the
# column names are taken from the wearing model's sidecar file.
m_trace, _ = load_pickle(mandate_pkl_0)
exp_trace, npi_cols = load_pickle(exp_pkl)
# NOTE(review): q2_trace and ll_trace are referenced by the "R reduction over
# wearing" figure further down, but their loads are commented out here.
#q2_trace, npi_cols = load_pickle(q2_pkl_1000_250)
#ll_trace, _ = load_pickle(ll_pkl)

print("Pickle has", exp_trace.RegionR.shape[1], "regions")




# Variable names of the wearing trace, minus the names listed here.
varnames = [c for c in exp_trace.varnames if c not in ["Psi_log__", "Infected_log", "HyperRVar_log__", "GrowthNoiseScale_log__", "r_walk_noise_scale_log__"]]
varnames

In [None]:
trace = exp_trace

In [None]:
# Summarise CMReduction plus whichever optional reduction variables exist in
# the mandate trace, with a 95% HDI.
# (Renaming the summary index from npi_cols was tried and left disabled.)
optional_vars = ("WearingReduction", "MandateReduction", "MobilityReduction")
v = ["CMReduction"] + [name for name in optional_vars if name in m_trace.varnames]

s = pm.summary(m_trace, var_names=v, hdi_prob=0.95)
s

# Significance as AUC

In [None]:
def p(samples) :
    """Return 1 minus (number of entries below 1) / ``len(samples)``."""
    n_below_one = np.sum(samples < 1)
    return 1 - n_below_one / len(samples)

p(exp_trace.WearingReduction), p(m_trace.MandateReduction)

In [None]:
def plot_posterior_renamed() :
    """Plot the posteriors of WearingReduction and MobilityReduction (95% HDI).

    Reads the module-level ``trace``. Despite the name, panels currently keep
    their default titles; the retitling code was disabled and the unused local
    that fed it has been removed.
    """
    pm.plots.plot_posterior(
        trace,
        var_names=["WearingReduction", "MobilityReduction"],
        hdi_prob=0.95,
    )

    plt.show()
    
# plot_posterior_renamed()

In [None]:
# Effective sample size and R-hat for every variable of the wearing trace.
# (arviz is already imported in the first cell; this import is redundant.)
import arviz as az
ess = az.ess(exp_trace)
rhat = az.rhat(exp_trace)

In [None]:
vs_minus_logs = [v for v in varnames if v not in ["r_walk_noise_scale_log__", "HyperRMean_lowerbound__"]]


def collate(stat, varnames):
    """Flatten per-variable diagnostics into a single 1-D float array.

    Args:
        stat: mapping from variable name to a diagnostic value. Entries with
            ``size > 1`` are flattened via ``to_dataframe().to_numpy()``;
            single-element entries are converted with ``float``.
        varnames: names to pull from ``stat``.

    Returns:
        1-D array: all multi-element entries (in ``varnames`` order) followed
        by all scalar entries (in ``varnames`` order).
    """
    vector_parts = []
    scalar_values = []
    for var in varnames :
        entry = stat[str(var)]
        if entry.size > 1:
            vector_parts.append(entry.to_dataframe().to_numpy().flatten())
        else:
            scalar_values.append(float(entry))

    # Concatenate the list directly. Wrapping arrays of different lengths in
    # np.array() first builds a ragged array, which recent NumPy rejects, and
    # np.concatenate on an empty list fails when every entry is scalar.
    parts = vector_parts + [np.asarray(scalar_values, dtype=float)]
    stat_all = np.concatenate(parts)
    # stat_all[stat_all > 100] = 1
    return stat_all

def diagnostics(tr) :
    """Print the R-hat statistic of every variable in trace ``tr``."""
    # Compute R-hat once for the whole trace instead of once per variable.
    all_rhats = pm.rhat(tr)
    for r in tr.varnames :
        print(f'Rhat({r}) = {all_rhats[r]}')

cols = sns.cubehelix_palette(3, start=0.2, light=0.6, dark=0.1, rot=0.2)

PNAS_WIDTH_INCHES = 3.4252
plt.figure(figsize=(PNAS_WIDTH_INCHES * 1.5, 2), dpi=400)
plt.subplot(121)
plt.hist(collate(rhat, vs_minus_logs), bins=40, color=cols[0])
# Raw string: "\h" is not a valid escape sequence in a normal string literal.
plt.title(r"$\hat{R}$", fontsize=8)
# Show counts in thousands. A formatter (rather than fixed tick-label strings)
# keeps labels in step with the ticks if the axis is rescaled later.
plt.gca().yaxis.set_major_formatter(
    mtick.FuncFormatter(lambda y, _pos: '{:,.0f}'.format(y / 1000) + 'k')
)
plt.ylabel("Number of parameters", fontsize=8)

def get_total_samples(t) :
    """Multiply the first two counts found in ``str(t)``.

    ``str(t)`` is expected to look like ``<MultiTrace: A xxx, B yyy, ...>``;
    the result is ``A * B``.
    """
    fields = str(t).replace("<MultiTrace: ", "").split(",")
    first_count = int(fields[0].split()[0])
    second_count = int(fields[1].split()[0])
    return first_count * second_count


# Right-hand panel: histogram of effective sample sizes, y axis hidden.
plt.subplot(122)
plt.hist(collate(ess, vs_minus_logs), bins=40, color=cols[0]) # / samples
#plt.xlim() # [0,2] 
plt.gca().axes.get_yaxis().set_visible(False)
plt.title("ESS", fontsize=8) 
plt.tight_layout()

plt.savefig("../outputs/mcmc_wearing.pdf", bbox_inches="tight")


# Prior / posterior

In [None]:
# Reload the plotting module so local edits are picked up, then draw and save
# the prior/posterior grid for the wearing and mandate traces.
# NOTE(review): `imp` is deprecated in favour of `importlib`; later cells call
# `imp.reload` and rely on this import having run.
import imp
imp.reload(pp)
pp.plot_all_pps(exp_trace, m_trace)
plt.tight_layout()
plt.savefig(f"../outputs/pp_grid.pdf", bbox_inches="tight")

# Fits and holdouts

In [None]:
imp.reload(pc)

# To plot every region instead:
# for region in data.Rs :
#     pc.epicurve_plot(data, oxcgrt, trace, region)

# 2x2 grid of epicurves for four regions; only the first panel asks for a legend.
pred_rs = ["Australia", "United Kingdom", "Nigeria", "Singapore"] 
f, (rowax, colax) = plt.subplots(2, 2, figsize=(10,6), dpi=500, sharex=True)
panels = [rowax[0], rowax[1], colax[0], colax[1]]

pc.epicurve_plot(data, oxcgrt, Ds, exp_trace, pred_rs[0], panels[0], leg=True)
for region, panel in zip(pred_rs[1:], panels[1:]):
    pc.epicurve_plot(data, oxcgrt, Ds, exp_trace, region, panel)

plt.savefig(f"../outputs/pred_curves_4.pdf", bbox_inches="tight")

# Key Panels

In [None]:
import imp
imp.reload(rp)

# One reproduction-number figure per listed region, each saved to its own PDF.
pred_rs = ["India", "Sweden", "United Kingdom", "South Africa"]
for r in pred_rs: #data.Rs :
    rp.reprod_plot(exp_trace, data, mobility_data, oxcgrt, third_party_rts, r, start_d_i=0) 
    #plt.show()
    plt.savefig(f"../outputs/region_plots_{r}.pdf", bbox_inches="tight")

# Mob vs Wearing

In [None]:
from scipy.stats import pearsonr

# Featurised frame held by the preprocessing object; the correlation code that
# used it is commented out below.
df = masks_object.df

# corrs = []
# for c in mobility_data.reset_index().country.unique() :
#     if c not in df.reset_index().country.unique() :
#         continue
    
#     mobc = mobility_data.loc[c]
#     cdf = df.loc[c]
#     j = cdf[["percent_mc"]].join(mobc.avg_mobility_no_parks_no_residential)
#     corr = np.corrcoef(j.avg_mobility_no_parks_no_residential, j.percent_mc)
#     if corr[0][1] != np.nan :
#         corrs.append(corr[0][1])

# corrs = [c for c in corrs if not np.isnan(c)]
#np.mean(corrs)

# print("rho (overall)", pearsonr(df.avg_mobility_no_parks_no_residential, df["percent_mc"]))

# Mandate vs wearing

In [None]:
from scipy.stats import pearsonr
df = mobility_data

# Pooled correlation between wearing and the H6 column of the modelling set.
print("rho (binary overall)", pearsonr(df.percent_mc, df["H6_Facial Coverings"])[0])

## Simpson's paradox
# corrs = []
# for c in df.reset_index().country.unique() :
#     cdf = df.loc[c]
#     corr = pearsonr(cdf.percent_mc, cdf["H6_Facial Coverings"])[0]
#     if corr != np.nan :
#         corrs.append(corr)
# corrs = [c for c in corrs if not np.isnan(c)]
# print("rho (binary by country)", np.mean(corrs))#, np.std(corrs))

# or original 4-level H6:
# Parse the string dates so the wearing column can be joined onto the OxCGRT
# rows, which carry parsed dates.
df2 = df.reset_index()
df2["date"] = pd.to_datetime(df2["date"])
df2 = df2[["country", "date", "percent_mc"]]
df2 = df2.set_index(["country", "date"])
ourox = oxcgrt[oxcgrt.CountryName.isin(data.Rs)][["CountryName", "date", "H6_Facial Coverings"]]
ourox.columns = ["country", "date", "H6_Facial Coverings"]
ourox = ourox.set_index(["country", "date"])
ourox = ourox.join(df2)

# NOTE(review): the join is a left join, so rows without a wearing value would
# carry NaN into pearsonr; confirm every (country, date) pair matches.
print("Original 4-level rho", pearsonr(ourox.percent_mc, ourox["H6_Facial Coverings"])[0] )

In [None]:
imp.reload(mw)
import epimodel.viz.mandate_wearing as mw


def post_mandate_rise(Rs, df):
    """Per-country change in mean wearing between early and late centred days.

    Uses ``mw.get_centred_summary(Rs, df)``. For each country that has rows
    with ``day > 23``, appends mean ``percent_mc`` over those rows minus mean
    ``percent_mc`` over rows with ``day < 7``. Countries without late rows are
    skipped.

    Returns:
        List of differences, one per retained country.
    """
    d = mw.get_centred_summary(Rs, df)
    ranges = []
    for c in d.country.unique() :
        in_country = d.country == c
        late = d[in_country & (d.day > 23)]
        if not len(late) :
            continue
        early = d[in_country & (d.day < 7)]
        ranges.append( late.percent_mc.mean() - early.percent_mc.mean() )

    return ranges

# 2.5th, 50th and 97.5th percentiles of the per-country change in wearing.
rises = post_mandate_rise(data.Rs, df)
piles = np.percentile(rises, [2.5, 50, 97.5])


def exp_reduction_vector(a, x) :
    """Return ``1 - exp(-a * x)``, element-wise."""
    remaining = np.exp((-1.0) * a * x)
    return 1 - remaining

# Reduction implied by the posterior-mean alpha at the three wearing-change
# percentiles computed above.
a = exp_trace.Wearing_Alpha.mean()
exp_reduction_vector(a, piles)

In [None]:
# df = df.reset_index()
# df[df.country == "Bangladesh"][["percent_mc"]].plot()
# plt.ylim(0,1)
# Number of distinct countries present in the centred summary.
summaries = mw.get_centred_summary(data.Rs, df)
len(summaries.country.unique())

# Fig. 1

In [None]:
df = mobility_data

imp.reload(mw)


# Alternative panels that were tried; only the per-country centred plot is used.
#mw.mandate_barplot(df)
#mw.mandate_distplot(df)
#mw.messy_centred_mandate_plot(data.Rs, df)
#mw.original_stringency_plot(ourox)
#mw.centred_mandate_plot(data.Rs, df)

mw.country_centred_mandate_plot(data.Rs, df)



plt.savefig("../outputs/mw_panels.pdf", bbox_inches="tight")

    

In [None]:
def _read_samples(path):
    """Read whitespace-separated floats from ``path`` into a float64 array.

    Splitting on whitespace keeps every sample whether or not the file ends
    with a newline; dropping the last "\\n"-split field would lose the final
    sample of a file that has no trailing newline.
    """
    with open(path, "r") as f :
        return np.array(f.read().split(), dtype=np.float64)


wred = _read_samples("wearing_reduction_samples.txt")
mred = _read_samples("mandate_reduction_samples.txt")

In [None]:
# import epimodel.viz.average_trends as at
# import imp
# imp.reload(at)
# df = mobility_data

# import matplotlib.image as mpimg




# fig = plt.figure(figsize=(15,7), dpi=500)
# gs = fig.add_gridspec(3,6)
# ax00 = fig.add_subplot(gs[0, 0])
# ax01 = fig.add_subplot(gs[0, 1])
# ax10 = fig.add_subplot(gs[1, 0])
# ax11 = fig.add_subplot(gs[1, 1])
# ax3 = fig.add_subplot(gs[1, 2:])
# ax4 = fig.add_subplot(gs[0, 2:4])
# ax5 = fig.add_subplot(gs[0, 4:])

# img = mpimg.imread('../outputs/val-1.png')
# ax3.imshow(img, aspect='auto')
# ax3.axis('off')
# ax3.set_title("G", loc="left", fontweight="bold")


# Ds = pd.to_datetime(Ds)
# at.plot_avg_daily_new(df, Ds, ax00)
# at.plot_mandates(df, Ds, ax01)
# at.plot_mob(df, Ds, ax11)
# at.plot_wearing(df, Ds, ax10)


# def main_result_posteriors(mred, wred, ax):
#     sns.kdeplot(wred, label="wearing", shade=True, ax=ax)    
#     sns.kdeplot(mred, label="mandate", color="green", shade=True, ax=ax)
#     ax.axvline(x=0, color="black", linestyle="--")
#     ax.set_xlabel("Inferred % reduction in R", fontsize=16)
#     ax.set_xlim(-20, 60)
#     ax.axes.get_yaxis().set_visible(False)
#     ax.legend(fontsize=12, frameon=False)

    
# def exp_reduction(a, x):
#     reductions = 1 - np.exp((-1.0) * a * x)
#     return reductions.mean()


# def get_median_reduction(a, df):
#     obs_ = []
#     r = exp_reduction

#     for c in df.reset_index().country.unique():
#         cdf = df.loc[c]
#         median_ = cdf.percent_mc.median()
#         med_reduction_r = r(a, median_)
#         actual = med_reduction_r
#         obs_.append(actual * 100)

#     return obs_
    
# def plot_median_wearing_effect(df, alpha, ax):
#     obs_ = get_median_reduction(alpha, df)
#     print(np.percentile(obs_, [2.5, 50, 97.5]))
#     sns.distplot(obs_, kde=True, hist=False, kde_kws={"shade": True}, ax=ax)

#     ax.set_xlabel("% R reduction (by regional wearing level)", fontsize=16)
#     ax.yaxis.set_ticks([])
#     ax.set_xlim(0, 30)

#     med = np.median(obs_)
#     print(med)
#     ax.axvline(x=med, color="black", linestyle="--", label="median")

#     ax.legend(fontsize=12, frameon=False, loc="upper left")


# main_result_posteriors(mred, wred, ax=ax4)
# ax4.set_title("E", loc='left', fontweight="bold")

# df = mobility_data
# alpha = exp_trace.Wearing_Alpha
# plot_median_wearing_effect(df, alpha, ax=ax5)
# ax5.set_title("F", loc='left', fontweight="bold")

# plt.tight_layout(pad=1.2)
# #plt.subplots_adjust(top=0.9)#, right=1)

# plt.savefig("../outputs/fig3_placeholder.pdf", bbox_inches="tight")

# Fig 1: Posteriors

In [None]:
import imp
imp.reload(rep)

PNAS_WIDTH_INCHES = 3.4252
# Create an empty figure: plt.subplots() would also add a full-figure Axes that
# the two plt.subplot() panels below then overlap.
fig = plt.figure(figsize=(PNAS_WIDTH_INCHES,PNAS_WIDTH_INCHES * 1.5), dpi=400)

# Panel A: posteriors of the mandate and wearing effects.
ax = plt.subplot(2,1,1)
rep.main_result_posteriors(m=m_trace, w=exp_trace, ax=ax)
plt.tight_layout()
plt.subplots_adjust(top=1.3)
plt.title("A", loc='left', fontweight="bold")

# Panel B: wearing effect at each region's median wearing level.
df = mobility_data
imp.reload(ew)
ax = plt.subplot(2,1,2)
plt.title("B", loc='left', fontweight="bold")
ew.plot_median_wearing_effect(df, exp_trace, "exp", ax)

plt.tight_layout()
plt.savefig("../outputs/main_results_vertical.pdf")

# What is the empirical wearing effect in this window?

In [None]:
imp.reload(ew)

df = masks_object.df

#ew.plot_actual_wearing_effect(df, exp_trace, "exp")
#ew.plot_max_wearing_effect(df, exp_trace, "exp", ax)

# Highest wearing level per country. Deliberately NOT named `maxes`: that
# global holds the OxCGRT indicator maxima read by subindex(), nat_subindex()
# and stringency(), and rebinding it would break them if they run again.
max_wearing_by_country = {c : df.loc[c].percent_mc.max()  for c in df.reset_index().country.unique()}

# Per-country spread (max - min) of wearing inside the window. The grouped
# frame holds only `percent_mc`, so no argument is needed; the first positional
# parameter of GroupBy.max/min is `numeric_only`, not a column name.
g = df.reset_index()[["country", "percent_mc"]].groupby("country")
ranges = g.max() - g.min()
print("Average change in wearing in window:", round(ranges.median().iloc[0], 3) )

# Fig 3: R reduction over wearing

In [None]:
def exp_reduction_vector(a, x) :
    """Exponential wearing curve: ``1 - exp(-a * x)``, element-wise."""
    exponent = (-1.0) * a * x
    return 1 - np.exp(exponent)

def relu(x):
    """Element-wise ``max(0, x)``."""
    floor = 0
    return np.maximum(floor, x)

def ll_reduction_vector(a, x) :
    """Log-linear wearing curve: ``-log(max(0, 1 - a * x))``, element-wise."""
    # max(0, .) is the same clamp that relu() applies.
    clipped = np.maximum(0, 1 - a * x)
    return - np.log(clipped)


def q2_reduction_vector(alphas, x) :
    """Quadratic wearing curve: ``-log(max(0, 1 - (alphas[0]*x + alphas[1]*x**2)))``."""
    linear_coef, quadratic_coef = alphas[0], alphas[1]
    w = linear_coef * x + quadratic_coef * x**2
    # max(0, .) is the same clamp that relu() applies.
    return - np.log(np.maximum(0, 1 - w))


def plot_param(tr, r, t, ax):
    """Plot a wearing-to-reduction curve with a 95% band on ``ax``.

    Args:
        tr: trace exposing ``Wearing_Alpha`` posterior samples.
        r: curve function called as ``r(alpha, x)`` with ``x`` in [0, 1].
        t: panel title.
        ax: Matplotlib axes to draw on.

    The central line uses the posterior median of alpha and the band uses the
    2.5% and 97.5% quantiles. When ``Wearing_Alpha`` has several columns (one
    per coefficient), quantiles are taken per column and passed to ``r`` as
    lists.
    """
    samples = np.asarray(tr.Wearing_Alpha)
    quantiles = [2.5, 50, 97.5]

    if samples.ndim > 1 and samples.shape[1] > 1 :
        # Take quantiles over the sample axis, separately for each coefficient.
        # Indexing samples[0] / samples[1] would pick the first two *draws*
        # rather than the two coefficient columns.
        lu, alpha, hi = (list(q) for q in np.percentile(samples, quantiles, axis=0))
    else:
        lu, alpha, hi = np.percentile(samples, quantiles)

    x = np.linspace(0, 1, 1000)
    y = r(alpha, x)

    # Both axes are shown in percent.
    ax.plot(x * 100, y * 100)

    los = r(lu, x) 
    his = r(hi, x)
    ax.fill_between(x * 100, los * 100, his * 100, alpha=0.2)

    ax.set_title(t, fontsize=10)
    ax.set_ylim(0,60)
    ax.set_xlim(0,100)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(decimals=0))
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(decimals=0))

# The quadratic and log-linear traces are not loaded by the pickle-loading cell
# (those lines are commented out there), so load them here when missing.
# Without this the cell raises NameError on a fresh Restart & Run All.
if "q2_trace" not in globals():
    q2_trace, _ = load_pickle(q2_pkl_1000_250)
if "ll_trace" not in globals():
    ll_trace, _ = load_pickle(ll_pkl)

f, (ax1,ax2,ax3) = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(5,7), dpi=400)
plot_param(exp_trace, exp_reduction_vector, "Exponential feature", ax1)
#plt.set_ylabel("", fontsize=16)
# Shared axis labels for the three stacked panels.
f.text(0.5, -0.02, '% mask wearing', ha='center', fontsize=16)
f.text(-0.02, 0.5, "Reduction in R", va='center', fontsize=16, rotation='vertical')

plot_param(q2_trace, q2_reduction_vector, "Quadratic feature", ax2)
plot_param(ll_trace, ll_reduction_vector, "Linear feature", ax3)
plt.tight_layout()
plt.savefig("../outputs/wearing_paramets.pdf", bbox_inches='tight')

# Max mandate effect on wearing in first wave

In [None]:
import imp
imp.reload(yg)


# Draw the first-wave mandate/wearing figure from the YouGov helper module.
yg.plot_earliest_mandates_against_wearing()

# NOTE(review): the return value is discarded because this is not the last
# expression of the cell; confirm whether it is called only for a side effect.
yg.get_before_after_mandate_change()


plt.savefig("../outputs/yg_first_wave.pdf", bbox_inches='tight')

# Posterior correlations

In [None]:
sns.set(style="ticks", font='DejaVu Serif', font_scale=0.3)


def post_alpha_corrs(tr, w, ax):
    """Heatmap of posterior correlations between the reduction parameters.

    Args:
        tr: trace containing ``CMReduction``, ``MobilityReduction`` and either
            ``WearingReduction`` or ``MandateReduction``.
        w: ``"Wearing"`` selects ``WearingReduction``; any other value selects
            ``MandateReduction``.
        ax: Matplotlib axes to draw the heatmap on.

    Column labels come from the module-level ``npi_cols`` (all but its last
    two entries), followed by the selected variable name and "mobility".
    """
    if w == "Wearing":
        m = "WearingReduction"
    else :
        m = "MandateReduction"
    df_exp = pm.trace_to_dataframe(tr, varnames = ['CMReduction', m, 'MobilityReduction'])

    df_exp.columns = npi_cols[:-2] + [m, "mobility"]
    corrs = df_exp.corr()
    # Generate a mask for the upper triangle.
    # Builtin `bool`: the `np.bool` alias no longer exists in current NumPy.
    mask = np.zeros_like(corrs, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(
        corrs,
        mask=mask,
        #cmap=cmap,
        #vmax=1,
        linewidths=0.5,
        annot=True,
        fmt=".2f",
        vmin=-1, 
        vmax=1,
        center=0,
        cbar=False,
        ax=ax,
    )
    #ax.set_ylim(0, -1)
    #ax.yaxis.get_major_ticks()[0].label.visible = False
    # Hide the first x tick label.
    ax.axes.get_xticklabels()[0].set_visible(False)
    
    

# Two stacked heatmaps: wearing model on top, mandate model below.
f, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.5, 7), dpi=700)


post_alpha_corrs(exp_trace, w="Wearing", ax=ax1)
ax1.set_title(f"Wearing model, posterior correlations in reductions", fontsize=6)

post_alpha_corrs(m_trace, w="Mandate", ax=ax2)
ax2.set_title(f"\n\nMandate model, posterior correlations in reductions", fontsize=6)

# NOTE(review): tight_layout() recomputes subplot spacing, so the hspace set
# on the previous line is likely overridden.
f.subplots_adjust(hspace=20)
plt.tight_layout()



plt.savefig("../outputs/posterior_corrs_mandate.pdf", bbox_inches='tight')


# UMD facts

In [None]:
# Same analysis window as the configuration cell.
Ds = pd.date_range("2020-05-01", "2020-09-21", freq="D")
wearing = pd.read_csv(
        "../data/raw/umd/umd_national_wearing.csv",
        parse_dates=["survey_date"],
        infer_datetime_format=True,
    ).drop_duplicates()
wearing_windowed = wearing[(wearing.survey_date >= Ds[0]) & (wearing.survey_date <= Ds[-1])]

# Total survey sample size inside the window. Not displayed: only the last
# expression of a cell is echoed.
wearing_windowed.sample_size.sum()


wearing2021 = pd.read_csv(
        "../data/raw/umd/umd_national_wearing_2021.csv",
        parse_dates=["survey_date"],
        infer_datetime_format=True,
    ).drop_duplicates()


# From here on `wearing` holds the first file plus the 2021 file.
wearing = pd.concat([wearing, wearing2021])

In [None]:
wearing.percent_mc.median()

In [None]:
# NOTE(review): sns.reset_orig() restores the original rcParams, so the
# font.size set on the line before it is probably undone. The sizes that matter
# are passed explicitly to the label/tick/legend calls below.
plt.rcParams['font.size'] = 8
sns.reset_orig()
plt.rcParams['font.family'] = "DejaVu Serif"

fig = plt.figure(figsize=(PNAS_WIDTH_INCHES, 2.2), dpi=400)

# Countries split off as the "vaccinated" group (duplicate "Italy" entry removed;
# membership is unchanged).
vax_countries = ["Israel", "United Kingdom", "United States", "Canada", "Chile", "Hungary", "Germany", "Italy", "France", "Mongolia", "Uruguay", "Qatar", "Finland", "Belgium", "Spain", "Netherlands", "Iceland", "Bahrain", "Bhutan", "Cyprus", "Malta", "United Arab Emirates", "Austria"]

# .copy() so the smoothed column is written to an independent frame rather
# than a slice of `wearing`.
# NOTE(review): the rolling mean runs over the rows in file order, not per
# country, so windows can span two countries. Confirm this is intended.
vax = wearing[wearing.country.isin(vax_countries)].copy()
vax["percent_mc"] = vax.percent_mc.rolling(3).mean()

notvax = wearing[~wearing.country.isin(vax_countries)].copy()
notvax["percent_mc"] = notvax.percent_mc.rolling(3).mean()

# Daily median and inter-quartile band across the non-vax countries.
notvax_by_date = notvax[["percent_mc", "survey_date"]].groupby("survey_date")
y = notvax_by_date.median()
lu = notvax_by_date.quantile(0.25)
hi = notvax_by_date.quantile(0.75)

# Hard-coded values for the first three dates (kept from the original cell).
# Written with a single .iloc call instead of chained `y.percent_mc.iloc[...]`
# so the assignment reliably reaches the frame.
y.iloc[0:3, y.columns.get_loc("percent_mc")] = 0.654529
lu.iloc[0:3, lu.columns.get_loc("percent_mc")] = 0.52964

# plt.title("World median wearing")
plt.plot(y * 100, label="median",  linewidth=0.5)#"< 40% vaccinated")
plt.fill_between(lu.index, lu.percent_mc * 100, hi.percent_mc * 100, alpha=0.2, label="50% CI")
ax = plt.gca()
ax.yaxis.set_major_formatter(mtick.PercentFormatter(decimals=0))
plt.ylim(40, 100)

plt.ylabel("% wearing", fontsize=10)
plt.yticks(fontsize=8)
plt.xlabel("", fontsize=10)
plt.xticks(fontsize=8)
plt.legend(loc="lower right", frameon=False, fontsize=8)

# Monthly ticks labelled like "May'20".
locator = mdates.MonthLocator()  # every month
fmt = mdates.DateFormatter("%b'%y")
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(fmt)

# Hide every other tick label after the first one. (The original ran this loop
# twice; the second pass had no further effect.)
for i, label in enumerate(ax.xaxis.get_ticklabels()[1:]):
    if i % 2 == 0:
        label.set_visible(False)

plt.savefig("../outputs/world_wearing_21.pdf", bbox_inches="tight")


In [None]:
lu

In [None]:
# .copy() so the smoothed column is written to an independent frame rather
# than a slice of `wearing`.
vax = wearing[wearing.country.isin(vax_countries)].copy()
vax["percent_mc"] = vax.percent_mc.rolling(7).mean()

notvax = wearing[~wearing.country.isin(vax_countries)].copy()
notvax["percent_mc"] = notvax.percent_mc.rolling(7).mean()


fig = plt.figure(figsize=(PNAS_WIDTH_INCHES, 1.5), dpi=400)

# One faint line per country, then the group median on top.
for c in vax.country.unique() :
    y = vax[vax.country == c][["percent_mc", "survey_date"]].groupby("survey_date").median()
    plt.plot(y * 100, alpha=0.5, color="blue")

y = vax[["percent_mc", "survey_date"]].groupby("survey_date").median()
plt.plot(y * 100, color="blue", label="> 40% vaccinated")

for c in notvax.country.unique() :
    y = notvax[notvax.country == c][["percent_mc", "survey_date"]].groupby("survey_date").median()
    plt.plot(y * 100, alpha=0.5, color="#CCC")

# `y` is left holding the non-vax median; a later cell reads it.
y = notvax[["percent_mc", "survey_date"]].groupby("survey_date").median()
plt.plot(y * 100, color="grey", label="< 40% vaccinated")
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(decimals=0))
# Only the two median lines carry labels, so the legend has two entries
# (the original call had no labelled artists to show).
plt.legend(loc="lower right")

# Separate file name: the previous cell already writes world_wearing_21.pdf
# and saving here under the same name would overwrite that figure.
plt.savefig("../outputs/world_wearing_21_by_country.pdf", bbox_inches="tight")

# Holdout nonwearers

In [None]:
# NOTE(review): only the final expression of a cell is displayed, so the
# results of the first and third expressions below are discarded.
y.iloc[-7:-1].mean()

# Highest wearing value observed for each country.
y = wearing[["percent_mc", "country"]].groupby("country").max()
#plt.bar(y.index, y.percent_mc)
y.sort_values("percent_mc")

y.mean()

# How much did the nonvax drop in May 2021?

In [None]:
# Last 31 dates of the non-vax daily median (presumably May 2021 — TODO confirm
# that the data ends on 31 May).
may_notvax = notvax[["percent_mc", "survey_date"]].groupby("survey_date").median().tail(31)

earlyMay = may_notvax.head(3).mean() 
lateMay = may_notvax.tail(2).mean() 
# NOTE(review): this is 1 - early/late, i.e. the change relative to the LATE
# level, and it is negative when wearing fell. A drop relative to the early
# level would be 1 - lateMay / earlyMay; confirm which is intended.
1 - (earlyMay / lateMay)

#may_notvax.head(12).mean() 

# How many NPIs are active at the start? Which are "reopening"?

In [None]:
# NOTE(review): `df` is whatever an earlier cell last bound it to, and
# `df.date` needs `date` to be a column, not an index level. The reset_index
# line below was left disabled; confirm which frame is expected here.
#df = df.reset_index()
start_df = df[df.date == "2020-05-01"]
npis = list(df.columns[4:15])


ps = []
for npi in npis:
    shares = start_df[npi].value_counts(normalize=True)
    # Share of regions with the NPI equal to 1 on the first day. An NPI that is
    # never 1 on that day has no entry for 1, so report 0 instead of raising
    # KeyError.
    pct_1 = shares.loc[1] if 1 in shares.index else 0.0
    print(npi, f"{pct_1:.1f}")
    ps.append(pct_1)

# NOTE(review): hand-typed values, not derived from `ps`.
np.mean([0.9, 0.8, 0.6, 0.5, 0.8])


# Recommendations

In [None]:
def threshold(df, col, t):
    """Boolean mask: ``df[col] >= t`` and the matching ``<code>_Flag`` equals 1.

    ``<code>`` is the first two characters of ``col``.
    """
    NATIONAL = 1
    flag_col = f"{col[:2]}_Flag"

    at_least_t = df[col] >= t
    is_national = df[flag_col] == NATIONAL
    return at_least_t & is_national

def load_and_clean_wearing():
    """Build the combined (UMD + US survey) wearing table for the window ``Ds``.

    Reads ``../data/raw/umd/umd_national_wearing.csv``, keeps rows whose
    ``survey_date`` lies in ``[Ds[0], Ds[-1]]``, appends the US state rows from
    ``load_and_clean_rader()`` (padded back to ``Ds[0]`` by
    ``add_dummy_wearing_us``) and returns the result of ``fill_missing_days``.

    Returns:
        DataFrame with columns ``country``, ``date``, ``percent_mc``.
    """
    cols = ["country", "date", "percent_mc"]

    umd = pd.read_csv(
        "../data/raw/umd/umd_national_wearing.csv",
        parse_dates=["survey_date"],
        infer_datetime_format=True,
    ).drop_duplicates()
    in_window = (umd.survey_date >= Ds[0]) & (umd.survey_date <= Ds[-1])
    wearing = umd[in_window][["country", "survey_date", "percent_mc"]]
    wearing.columns = cols

    # Append US
    # Positional rename of the rader frame's three columns, then reorder.
    us_wearing = load_and_clean_rader()
    us_wearing.columns = ["date", "country", "percent_mc"]
    us_wearing = us_wearing[cols]
    for old_name, new_name in (
        ("Georgia", "Georgia-US"),
        ("District of Columbia (DC)", "District of Columbia"),
    ):
        us_wearing = us_wearing.replace(old_name, new_name)
    # Add dummy wearing back to 1st May
    us_wearing = add_dummy_wearing_us(us_wearing, backfill=True)

    combined = pd.concat([wearing, us_wearing])
    return fill_missing_days(combined)

def load_and_clean_rader(THRESHOLD=2, SMOOTH_RADER=True):  # or less
    """Load the US survey file and reduce it to a daily per-state wearing share.

    Reads ``../data/raw/rader/sm_cny_data_1_21_21.csv`` and the codebook
    ``../data/raw/rader/cny_sm_codebook_2_5_21.xls`` (used to turn the numeric
    ``state`` into its ``label``).

    Args:
        THRESHOLD: passed to ``mean_shop_work``; a response counts when its
            mean venue score is at or below this value.
        SMOOTH_RADER: when True the result is passed through ``smooth_rader``.

    Returns:
        DataFrame with columns ``response_date``, ``label``, ``percent_mc``,
        where ``percent_mc`` is the mean of the per-response indicator for
        each (date, state).
    """
    DATA_IN = "../data/raw/"
    survey = pd.read_csv(DATA_IN + "rader/sm_cny_data_1_21_21.csv")

    masks = [
        "likely_wear_mask_exercising_outside",
        "likely_wear_mask_grocery_shopping",
        "likely_wear_mask_visit_family_friends",
        "likely_wear_mask_workplace",
    ]
    # weights = ["weight_daily_national_13plus", "weight_state_weekly"]
    survey = survey[["response_date", "state"] + masks]  # + weights

    # Attach the state label from the codebook and drop the numeric codes.
    codes = pd.read_excel(DATA_IN + "rader/cny_sm_codebook_2_5_21.xls")
    num2name = codes[codes["column"] == "state"][["value", "label"]]
    survey = pd.merge(survey, num2name, left_on="state", right_on="value")
    survey = survey.drop(["value", "state"], axis=1)

    survey["response_date"] = pd.to_datetime(survey["response_date"])
    survey["percent_mc"] = mean_shop_work(survey, THRESHOLD)

    daily = survey[["response_date", "label", "percent_mc"]]
    daily = daily.groupby(["response_date", "label"]).mean().reset_index()

    return smooth_rader(daily) if SMOOTH_RADER else daily


def mean_shop_work(df, THRESHOLD=2):
    """Flag rows whose mean grocery/workplace score is at or below ``THRESHOLD``.

    Side effect: writes the row-wise mean of the two venue columns into
    ``df["percent_mc"]``.

    Returns:
        Boolean Series, True where that mean is ``<= THRESHOLD``.
    """
    venue_scores = df[["likely_wear_mask_grocery_shopping", "likely_wear_mask_workplace"]]
    df["percent_mc"] = venue_scores.mean(axis=1)

    return df["percent_mc"].le(THRESHOLD)


def smooth_rader(df, win=7):
    """Smooth ``percent_mc`` separately for every ``label`` in ``df`` (in place).

    Each label's series is passed through ``smooth`` and the first ``n``
    smoothed values are kept, where ``n`` is the number of rows for that label.

    Args:
        df: frame with ``label`` and ``percent_mc`` columns; modified in place.
        win: window length handed to ``smooth``.

    Returns:
        The same ``df`` object.
    """
    for r in df.label.unique():
        rows = df.label == r
        n_rows = int(rows.sum())
        smoothed = smooth(df.loc[rows, "percent_mc"], window_len=win)
        # Keep the first n values. The original slice `[: -win + 1]` is the
        # same for win > 1 but becomes `[:0]` (empty) when win == 1.
        # Writing through .loc avoids assigning to a temporary slice.
        df.loc[rows, "percent_mc"] = smoothed[:n_rows]

    return df


def smooth(x, window_len=7):
    """Moving average of ``x`` after mirror-padding both ends.

    ``x`` is extended with reversed slices of itself at each end and then
    convolved (``mode="valid"``) with a flat kernel of length ``window_len``.
    For inputs at least ``window_len`` long the result has
    ``len(x) + window_len - 1`` elements.
    """
    head = x[window_len - 1 : 0 : -1]
    tail = x[-2 : -window_len - 1 : -1]
    padded = np.r_[head, x, tail]

    kernel = np.ones(window_len, "d")
    kernel = kernel / kernel.sum()

    return np.convolve(kernel, padded, mode="valid")

def join_ox_umd(oxcgrt, wearing, npi_cols):
    """Inner-join policy rows with wearing rows on (country, date).

    ``oxcgrt.CountryName`` is matched against ``wearing.country``; clashing
    column names from ``wearing`` get a trailing underscore.

    Returns:
        The joined frame restricted to ``npi_cols`` followed by ``country``,
        ``date``, ``ConfirmedCases`` and ``ConfirmedDeaths``.
    """
    merged = pd.merge(
        oxcgrt,
        wearing,
        left_on=["CountryName", "date"],
        right_on=["country", "date"],
        suffixes=("", "_"),
    )
    wanted = npi_cols + ["country", "date", "ConfirmedCases", "ConfirmedDeaths"]

    return merged[wanted]


def add_dummy_wearing_us(us, backfill=True):
    """Prepend placeholder wearing rows for the days before the US data starts.

    The gap runs from ``Ds[0]`` to the day before ``us.date.iloc[0]``.

    Args:
        us: frame with ``date``, ``country`` and ``percent_mc`` columns.
            NOTE(review): the first row is assumed to carry the earliest date.
        backfill: when True each state's gap is filled with that state's first
            ``percent_mc`` value; otherwise with uniform random numbers.

    Returns:
        ``us`` plus the filler rows, sorted by ``date`` then ``country``.
    """
    rader_start = us.date.iloc[0] - timedelta(days=1)
    fill_days = pd.date_range(Ds[0], rader_start, freq="D")

    if backfill:
        # Index once instead of re-indexing the growing frame for every state.
        by_state = us.set_index(["country", "date"]).percent_mc

    fillers = []
    for s in us.country.unique():
        if backfill:
            fill = by_state.loc[s].iloc[0]
        else:
            # totally random dummy
            fill = np.random.random(len(fill_days))
        fillers.append(
            pd.DataFrame({"date": fill_days, "country": s, "percent_mc": fill})
        )

    # Single concat instead of one per state. Reversing reproduces the row
    # order that repeatedly prepending each state's block used to give.
    us = pd.concat(fillers[::-1] + [us])

    us = us.sort_values(["date", "country"])
    return us


def fill_missing_days(df):
    """Insert rows for missing days and interpolate the gaps per country.

    Expects ``date`` and ``country`` columns. Returns the output of
    ``interpolate_wearing_fwd_bwd`` applied to the completed frame.
    """
    df = df.set_index(["date", "country"])
    # Pivot countries into columns, move to a daily index, and pivot back.
    # Gaps are marked with the sentinel -1 at both steps and turned into NaN
    # on the next line.
    df = df.unstack(fill_value=-1).asfreq("D", fill_value=-1).stack().reset_index()
    # NOTE(review): this replaces every -1 in the frame, not only the sentinels;
    # it assumes no genuine value equals -1.
    df = df.replace(-1, np.nan)

    return interpolate_wearing_fwd_bwd(df)


def interpolate_wearing_fwd_bwd(df):
    """Time-interpolate missing numeric values separately for each country.

    Args:
        df: frame with ``country`` and ``date`` columns plus numeric columns
            to fill (``date`` must be datetime-like for ``method="time"``).

    Returns:
        The per-country frames concatenated in order of first appearance, with
        gaps filled in both directions.
    """
    regions = df.country.unique()
    cs = []

    for r in regions:
        c = df[df.country == r].set_index("date")
        # Interpolate only the numeric columns: recent pandas deprecates (and
        # may refuse) interpolating a frame that also holds object-dtype
        # columns such as `country`.
        numeric_cols = c.select_dtypes(include="number").columns
        if len(numeric_cols):
            c[numeric_cols] = c[numeric_cols].interpolate(
                method="time", limit_direction="both"
            )
        cs.append(c.reset_index())

    return pd.concat(cs)


# Columns kept from the OxCGRT/wearing join: five C-indicators, the mask
# mandate indicator and the wearing share.
npi_cols_raw_no_mob = [
            "C1_School closing",
            "C2_Workplace closing",
            "C4_Restrictions on gatherings",
            "C6_Stay at home requirements",
            "C7_Restrictions on internal movement",
        ] \
    + ["H6_Facial Coverings"] \
    + ["percent_mc"]


# Work on a copy so the H6 column of `oxcgrt` itself is left untouched.
oxcgrt2 = oxcgrt.copy()
col = "H6_Facial Coverings"
# oxcgrt2["H6_Facial Coverings"] = (oxcgrt2[col] < 2) & (oxcgrt2[col] == 1) & (oxcgrt2[f"H6_Flag"] == 1) #
# Binarise H6: True where the level is >= 1 and H6_Flag == 1.
oxcgrt2["H6_Facial Coverings"] = threshold(oxcgrt2, "H6_Facial Coverings", 1)
# NOTE(review): this rebinds `wearing`, which earlier cells use for the raw
# UMD frame (including the 2021 data); re-running those cells afterwards would
# see this cleaned table instead.
wearing = load_and_clean_wearing()
join = join_ox_umd(oxcgrt2, wearing, npi_cols_raw_no_mob)
join = join.set_index(["country", "date"])
#[["CountryName", "H6_Facial Coverings", "ConfirmedCases"]]

In [None]:
# NOTE(review): not displayed; only the last expression of a cell is echoed.
join[["H6_Facial Coverings","percent_mc"]].corr()
import scipy
# Pearson and Spearman correlation between the binarised mandate and wearing,
# rounded to two decimals.
round(scipy.stats.pearsonr(join["H6_Facial Coverings"], join["percent_mc"])[0], 2), \
round(scipy.stats.spearmanr(join["H6_Facial Coverings"], join["percent_mc"])[0], 2)

In [None]:
df = mobility_data
summaries = mw.get_centred_summary(["Netherlands"], df)

fig, ax = plt.subplots()

# Wearing over days relative to the mandate (day 0, dashed line) for one country.
c = "Netherlands"
cs = summaries[summaries.country == c]
ax.plot(cs.day, cs.percent_mc * 100)# label="median wearing")

# The original set an unused `tsize` for long names; the title size is fixed.
ax.set_title(c, fontsize=10)


ax.axvline(x=0, color="black", linestyle="--")
ax.yaxis.set_major_formatter(mtick.PercentFormatter(decimals=0))
ax.set_ylim(0, 100)
ax.set_xlim(-20, 20)

fig.text(0.5, -0.02, 'Days since mandate', ha='center', fontsize=16)
fig.text(-0.02, 0.5, "% mask wearing", va='center', fontsize=16, rotation='vertical')
plt.tight_layout(pad=0.4)