In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../") 

import datetime
import json
import os
import re
import textwrap

import numpy as np
import yaml

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import arviz as az

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
#### Instructions: completely re-run for both DATA sets
# Dataset selector for the whole notebook; switch by (un)commenting one line.
DATA = 'BraunerTE'
#DATA = 'Sharma'

# Passed as `default_res` to plot_experiment_class further down; None means no
# additional baseline trace is drawn.
default_alpha = None

if DATA != "Sharma":
    # Brauner-style dataset: fixed date window, all NPIs featurized as binary.
    data = preprocess_data(
        '../../d/modelSharma_dataBraunerTE.csv',
        start_date="2020-01-22",
        end_date="2020-05-30",
        household_feature_processing="raw",
    )
    data.featurize(all_binary=True)
    data.mask_from_date("2020-05-30")
else:
    # Sharma dataset: drop two NPIs by exact name, then apply both masks.
    data = preprocess_data('../../d/modelSharma_dataSharma.csv')
    data.featurize(
        drop_npi_filter=[
            {"query": "Childcare Closed", "type": "equals"},
            {"query": "All Face-to-Face Businesses Closed", "type": "equals"},
        ]
    )
    data.mask_new_variant(
        new_variant_fraction_fname='../../d/modelSharma_dataSharma_nuts3_new_variant_fraction.csv'
    )
    data.mask_from_date('2021-01-09')

# Countermeasure names as produced by the preprocessing of the selected dataset.
default_names = data.CMs
print(default_names)

In [None]:
# Plot configuration for the dataset selected via DATA.
#
# `grouped_npis` maps each y-axis row label (keys, in display order) to the raw
# NPI names ('npis') whose samples are combined for that row; it is passed as
# `npi_comb_dict` to the plotting helpers in this notebook.
# NOTE(review): the semantics of 'type' ('exclude' / 'include'), 'color' and
# 'main' are defined by the consumer (presumably combine_npi_samples) — confirm
# there; the plotting code in this notebook does not read 'color' or 'main'.
# NOTE(review): `corrected_names` is not referenced elsewhere in this notebook.
cols = sns.color_palette('colorblind')

if DATA == "Sharma":
    # Human-readable names for the individual Sharma NPIs.
    corrected_names = [
        'Night clubs closed',
        'Gastronomy closed',
        'Leisure venues closed',
        'Retail and close-contact\nservices closed',
        'Night time curfew',
        'Primary schools closed',
        'Secondary schools closed',
        'Universities closed',
        'All public gatherings banned',
        'Public gatherings limited to 2 people', 
        'Public gatherings limited to ≤10 people',
        'Public gatherings limited to ≤30 people', 
        'Public gatherings limited to 2 households',
        'All household mixing in private banned',
        'Household mixing in private limited to 2 people', 
        'Household mixing in private limited to ≤10 people',
        'Household mixing in private limited to ≤30 people', 
        'Household mixing in private limited to 2 households',
        'Stricter mask-wearing\npolicy',
    ]

    grouped_npis = {
        # --- Business closures -------------------------------------------------
        'All non-essential\nbusinesses closed': {
            'npis': ['Retail Closed', 'Some Face-to-Face Businesses Closed', 
                    'Gastronomy Closed', 'Leisure Venues Closed'],
            'type': 'exclude',
            'color': cols[0],
            'main': True,
        },
        'Night clubs closed': {
            'npis': ['Some Face-to-Face Businesses Closed'],
            'type': "exclude",
            'color': cols[0],
            'main': False,
        },
        'Leisure and entertainment\nvenues closed': {
            'npis': ['Leisure Venues Closed'],
            'type': 'exclude',
            'color': cols[0],
            'main': False,
        },
        'Gastronomy closed': {
            'npis': ['Gastronomy Closed'],
            'type': 'exclude',
            'color': cols[0],
            'main': False,
        },
        'Retail and close-contact\nservices closed': {
            'npis': ['Retail Closed'],
            'type': 'exclude',
            'color': cols[0],
            'main': False,
        },
        # --- Public + private gatherings combined ------------------------------
        # Each weaker limit lists a subset of the NPIs of the stricter one; the
        # commented-out entries mark the NPIs left out of that group.
        'All gatherings banned': {
            'npis': ['Public Indoor Gathering Person Limit - 1',
                    'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    'Extra Public Indoor Household Limit',
                    'Private Indoor Gathering Person Limit - 1',
                    'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[3],
            'main': True
        },
        'All gatherings limited to 2 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    'Extra Public Indoor Household Limit',
                    #'Private Indoor Gathering Person Limit - 1',
                    'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    'Extra Private Indoor Household Limit'
                    ],
            "type": "exclude",
            'color': cols[3],
            'main': False
        },
        'All gatherings limited to ≤10 people\nfrom 2 households': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    'Extra Public Indoor Household Limit',
                    #'Private Indoor Gathering Person Limit - 1',
                    #'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[3],
            'main': False
        },
        'All gatherings limited to ≤10 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    #'Extra Public Indoor Household Limit',
                    #'Private Indoor Gathering Person Limit - 1',
                    #'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    #'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[3],
            'main': False
        },
        'All gatherings limited to ≤30 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Public Indoor Gathering Person Limit - 2',
                    #'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    #'Extra Public Indoor Household Limit',
                    #'Private Indoor Gathering Person Limit - 1',
                    #'Private Indoor Gathering Person Limit - 2',
                    #'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    #'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[3],
            'main': False
        },
        # --- Education, curfew, masks ------------------------------------------
        'All educational\ninstitutions closed': {
            'npis': ['Primary Schools Closed', 'Secondary Schools Closed', 'Universities Away'],
            'type': 'exclude',
            'color': cols[2],
            'main': True
        },
        'Night time curfew': {
            'npis': ['Curfew'],
            'type': 'exclude',
            'color': cols[1],
            'main': True
        },
        'Stricter mask-wearing\npolicy': {
            'npis': ['Mandatory Mask Wearing >= 3'],
            'type': 'exclude',
            'color': cols[1],
            'main': True
        },
        # --- Public gatherings only (entries from here on carry no 'main' key) --
        'All public gatherings banned': {
            'npis': ['Public Indoor Gathering Person Limit - 1',
                    'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    'Extra Public Indoor Household Limit',
                    ],
            'type': 'exclude',
            'color': cols[4],
        },
        'Public gatherings limited to 2 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    'Extra Public Indoor Household Limit',
                    ],
            "type": "exclude",
            'color': cols[4],
        },
        'Public gatherings limited to ≤10 people\nfrom 2 households': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    'Extra Public Indoor Household Limit',
                    ],
            'type': 'exclude',
            'color': cols[4],
        },
        'Public gatherings limited to ≤10 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    ],
            'type': 'exclude',
            'color': cols[4],
        },
        'Public gatherings limited to ≤30 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Public Indoor Gathering Person Limit - 2',
                    #'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    ],
            'type': 'exclude',
            'color': cols[4],
        },
        # --- Private household mixing only -------------------------------------
        'All household mixing in private banned': {
            'npis': ['Private Indoor Gathering Person Limit - 1',
                    'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[5],
        },
        'Household mixing in private\nlimited to 2 people': {
            'npis': [#'Private Indoor Gathering Person Limit - 1',
                    'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    'Extra Private Indoor Household Limit'
                    ],
            "type": "exclude",
            'color': cols[5],
        },
        'Household mixing in private\nlimited to ≤10 people from 2 households': {
            'npis': [#'Private Indoor Gathering Person Limit - 1',
                    #'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[5],
        },
        'Household mixing in private\nlimited to ≤10 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Extra Public Indoor Household Limit',
                    #'Private Indoor Gathering Person Limit - 1',
                    #'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    #'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[5],
        },
        'Household mixing in private\nlimited to ≤30 people': {
            'npis': [#'Public Indoor Gathering Person Limit - 1',
                    #'Public Indoor Gathering Person Limit - 2',
                    #'Public Indoor Gathering Person Limit - 10',
                    #'Extra Public Indoor Household Limit',
                    #'Private Indoor Gathering Person Limit - 1',
                    #'Private Indoor Gathering Person Limit - 2',
                    #'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    #'Extra Private Indoor Household Limit'
                    ],
            'type': 'exclude',
            'color': cols[5],
        },
        # --- Everything together (the only 'include' entry; has no 'color') -----
        'Combined effect of all NPIs': {
            'npis': ['Some Face-to-Face Businesses Closed',
                    'Gastronomy Closed',
                    'Leisure Venues Closed',
                    'Retail Closed',
                    'Curfew',
                    'Primary Schools Closed',
                    'Secondary Schools Closed',
                    'Universities Away',
                    'Public Indoor Gathering Person Limit - 1',
                    'Public Indoor Gathering Person Limit - 2',
                    'Public Indoor Gathering Person Limit - 10',
                    'Public Indoor Gathering Person Limit - 30',
                    'Extra Public Indoor Household Limit',
                    'Private Indoor Gathering Person Limit - 1',
                    'Private Indoor Gathering Person Limit - 2',
                    'Private Indoor Gathering Person Limit - 10',
                    'Private Indoor Gathering Person Limit - 30',
                    'Extra Private Indoor Household Limit',
                    'Mandatory Mask Wearing >= 3'],
            "type": 'include'
        }
    }

else:
    # Any DATA value other than "Sharma" uses the Brauner NPI set.
    corrected_names = ['Gatherings <1000', 'Gatherings <100', 'Gatherings <10',
        'Some Businesses Suspended', 'Most Businesses Suspended', 'School Closure', 'University Closure', 'Stay Home Order']
    grouped_npis = {
        # Gathering limits: each stricter limit also lists the looser ones.
        'Gatherings limited to\n1000 people or less': {
            'npis': ['Gatherings <1000'],
            'type': 'exclude', 'color': cols[0],
        },
        'Gatherings limited to\n100 people or less': {
            'npis': ['Gatherings <1000', 'Gatherings <100'],
            'type': 'exclude', 'color': cols[0],
        },
        'Gatherings limited to\n10 people or less': {
            'npis': ['Gatherings <1000', 'Gatherings <100', 'Gatherings <10'],
            'type': 'exclude', 'color': cols[0],
        },
        'Some businesses closed': {
            'npis': ['Some Businesses Suspended'],
            'type': 'exclude', 'color': cols[1],
        },
        'Most nonessential\nbusinesses closed': {
            'npis': ['Some Businesses Suspended', 'Most Businesses Suspended'],
            'type': 'exclude', 'color': cols[1],
        },
        'Schools and universities\nclosed': {
            'npis': ['School Closure', 'University Closure'],
            'type': 'exclude', 'color': cols[2],
        },
        # NOTE(review): shares cols[0] with the gathering rows — confirm intended.
        'Additional benefit of\nstay-at-home order': {
            'npis': ['Stay Home Order'],
            'type': 'exclude', 'color': cols[0],
        },
        # The only 'include' entry; has no 'color'.
        'Combined effect of all NPIs': {
            'npis': ['Gatherings <1000', 'Gatherings <100', 'Gatherings <10',
                     'Some Businesses Suspended', 'Most Businesses Suspended',
                     'School Closure', 'University Closure', 'Stay Home Order'],
            "type": 'include'
        }
    }


In [None]:
def intervention_prior_labeler(d):
    """Build a legend label for the intervention prior of one experiment.

    Reads ``d['exp_config']['intervention_prior']`` and checks, in this order:
    a scale of 20 (AsymmetricLaplace label), a ``type`` of ``"normal"``
    (Normal label quoting the configured scale verbatim), a scale of 0.15
    (HalfNormal label). Returns ``None`` when none of the three applies.
    """
    prior = d['exp_config']['intervention_prior']

    if float(prior['scale']) == 20.0:
        return  "AsymmetricLaplace(0, 0.5, 20)"

    if prior['type'] == "normal":
        return f"Normal(0, {prior['scale']}$^2$)"

    if float(prior['scale']) == 0.15:
        return "HalfNormal(0, 0.15$^2$)"

    return None

def get_all_experiments(path):
    """Recursively load every ``.json`` experiment result found under ``path``.

    Each successfully parsed file is annotated with three extra keys:

    * ``MODEL`` / ``DATA`` -- extracted from its ``model_config_name`` entry,
      which must match ``model<MODEL>_...`` and ``...data<DATA>``;
    * ``FILENAME`` -- the bare file name.

    Loading is best-effort: a file that is not valid JSON, lacks
    ``model_config_name``, or has a name that does not match the patterns is
    reported (with the reason) and skipped. Other errors, such as a missing
    import, are deliberately not swallowed.

    Args:
        path: root directory to search.

    Returns:
        list of dict: the loaded experiments, in ``os.walk`` order.
    """
    experiments = []
    for subdir, _dirs, files in os.walk(f'{path}'):
        for filename in files:
            if not filename.endswith('.json'):
                continue

            filepath = os.path.join(subdir, filename)
            print(filepath)
            with open(filepath) as f:
                try:
                    data = json.load(f)
                    config_name = data['model_config_name']
                    model_match = re.search('model(.*)_', config_name)
                    data_match = re.search('data(.*)', config_name)
                    if model_match is None or data_match is None:
                        raise ValueError(
                            f"model_config_name {config_name!r} does not look like 'model<M>_data<D>'"
                        )
                    data["MODEL"] = model_match.groups()[0]
                    data["DATA"] = data_match.groups()[0]
                    data["FILENAME"] = filename
                    experiments.append(data)
                except (ValueError, KeyError, TypeError) as err:
                    # ValueError also covers json.JSONDecodeError and decoding errors.
                    print(f'failed to load {filename}: {err!r}')
    return experiments

def filter_by_exp_tag(experiments, exp_tag, exp_info):
    """Select the experiments that belong to one experiment class.

    An experiment is kept when its ``'exp_tag'`` equals ``exp_tag``; failing
    that, when ``exp_info`` provides an ``"exp_matcher"`` callable that returns
    a truthy value for it. Input order is preserved.
    """
    def belongs(candidate):
        if candidate['exp_tag'] == exp_tag:
            return True
        # The matcher is only consulted for experiments whose tag did not match.
        return "exp_matcher" in exp_info and exp_info["exp_matcher"](candidate)

    return [candidate for candidate in experiments if belongs(candidate)]

class experiment_type():
    """A group of experiments sharing one tag, plus its plotting metadata.

    The constructor stores the given ``experiments`` list and ``exp_info`` dict
    by reference, records ``tag`` under ``exp_info["tag"]``, and -- when the
    first experiment has an ``'alpha_i'`` entry -- sorts the list in place by
    the median of the first column of each experiment's ``alpha_i`` array.
    """
    def __init__(self, experiments, exp_info, tag):
        exp_info["tag"] = tag
        self.exp_info = exp_info
        self.experiments = experiments

        first_experiment = experiments[0]
        if "alpha_i" in first_experiment.keys():
            def median_of_first_column(exp):
                return np.median(np.array(exp['alpha_i'])[:, 0])

            experiments.sort(key=median_of_first_column)

def get_unique_exp_tags(experiments):
    """Return the distinct ``'exp_tag'`` values of ``experiments`` as a sorted list."""
    tags = []
    for exp in experiments:
        tags.append(exp['exp_tag'])
    unique_tags = np.unique(tags)
    return list(unique_tags)

def make_all_experiment_classes(all_experiments, all_exp_info):
    """Build one ``experiment_type`` per tag in ``all_exp_info`` that has matches.

    For every ``tag -> info`` entry the experiments are selected with
    ``filter_by_exp_tag``; tags with no matching experiment are skipped.
    Returns the resulting objects in the iteration order of ``all_exp_info``.
    """
    classes = []
    for tag in all_exp_info:
        info = all_exp_info[tag]
        matching = filter_by_exp_tag(all_experiments, tag, info)
        if not matching:
            continue
        classes.append(experiment_type(matching, info, tag))

    return classes

In [None]:
def add_trace_to_plot(samples, y_off, col, label, alpha, width, npi_comb_dict, cm_names, size=6, extra={}):
    """Draw one experiment's medians and intervals onto the current axes.

    ``samples`` are combined per group with ``combine_npi_samples`` and then
    transformed with ``100 * (1 - exp(-x))`` before plotting. Each column is
    drawn at ``-(row index of its name in npi_comb_dict) + y_off`` as:

    * a white-filled marker at the median,
    * a faint line spanning the 2.5-97.5 percentiles,
    * a stronger line spanning the 25-75 percentiles.

    Args:
        samples, cm_names: forwarded to ``combine_npi_samples``
            (presumably samples are (n_samples, n_cms) -- confirm there).
        y_off: vertical offset so several traces can share a row.
        col, alpha, width, size: colour, opacity, line width and marker size.
        label: legend entry for this trace.
        npi_comb_dict: group definitions; key order fixes the row order.
        extra: optional ``name -> values`` mapping appended as additional
            columns/rows after the grouped effects, plotted untransformed.
            It is only read, never modified.
    """
    effects, names = combine_npi_samples(npi_comb_dict, samples, cm_names)    
    effects = 100*(1-np.exp(-effects))
    row_order = list(npi_comb_dict.keys())

    # Extra quantities become additional columns and additional rows at the bottom.
    for extra_name, extra_values in extra.items():
        extra_column = np.reshape(extra_values, (-1, 1))
        effects = np.concatenate([effects, extra_column], axis=1)
        names.append(extra_name)
        row_order.append(extra_name)

    y_vals = -np.array([row_order.index(name) for name in names])

    # Single-point line at (100, 100) whose only job is to carry the legend label.
    plt.plot([100], [100], color=col, linewidth=1, alpha=alpha, label=label)

    li, lq, m, uq, ui = np.percentile(effects, [2.5, 25, 50, 75, 97.5], axis=0)
    plt.scatter(m, y_vals+y_off, marker="o", color=col, s=size, alpha=alpha, facecolor='white', zorder=3, linewidth=width/2)

    for cm, y_row in enumerate(y_vals):
        y_pos = y_row + y_off
        outer_span = [li[cm], ui[cm]]
        inner_span = [lq[cm], uq[cm]]
        plt.plot(outer_span, [y_pos, y_pos], color=col, alpha=alpha*0.25, linewidth=width, zorder=2)
        plt.plot(inner_span, [y_pos, y_pos], color=col, alpha=alpha*0.75, linewidth=width, zorder=2)


def setup_plot(experiment_class, npi_comb_dict, y_ticks = True, xlabel=True, x_lims=(-25, 50), newfig=True):
    """Prepare the current axes for a grouped NPI effect plot.

    Draws the dashed zero line, shades every other row, sets ticks and limits,
    and optionally labels the x axis.

    Args:
        experiment_class: unused; kept so existing calls keep working.
        npi_comb_dict: mapping whose keys (in order) are the row labels.
        y_ticks: if true, label each row with its key; otherwise hide y ticks.
        xlabel: if true, label the x axis "Reduction in R".
        x_lims: ``(x_min, x_max)`` for the x axis. ``None`` falls back to the
            default ``(-25, 50)``; previously ``None`` (which
            ``plot_experiment_class`` passes by default) raised ``TypeError``.
        newfig: if true, open a new 4x6 inch figure at 400 dpi first.
    """
    if x_lims is None:
        x_lims = (-25, 50)

    if newfig:
        plt.figure(figsize=(4, 6), dpi=400)

    x_min, x_max = x_lims

    npi_order = list(npi_comb_dict.keys())
    n_rows = len(npi_order)

    # Dashed vertical reference at zero effect.
    plt.plot([0, 0], [1, -(n_rows+2)], "--k", linewidth=0.5)

    xrange = np.array([x_min, x_max])

    # Light grey band behind every second row (rows 0, 2, 4, ...).
    for height in range(0, n_rows+2, 2):
        plt.fill_between(xrange, -(height-0.5), -(height+0.5), color="silver", alpha=0.25, linewidth=0)
    xtick_vals = [-25, 0, 25, 50, 75, 100]
    xtick_str = [f"{x:.0f}%" for x in xtick_vals]

    if y_ticks:
        plt.yticks(-np.arange(n_rows), npi_order, fontsize=6)
    else:
        plt.yticks([])

    plt.xticks(xtick_vals, xtick_str, fontsize=8)
    plt.xlim([x_min, x_max])
    plt.ylim([-(n_rows - 0.25), 0.75])

    # Horizontal rule 11.5 rows above -n_rows.
    # NOTE(review): hard-coded offset; presumably separates the main groups from
    # the detailed breakdown in the Sharma layout — it moves when extra rows are
    # appended and falls outside the y-limits for short tables. Confirm intent.
    plt.plot([-100, 100], [-n_rows+11.5, -n_rows+11.5], 'k')

    if xlabel:
        plt.xlabel("Reduction in R", fontsize=8)

# One colour per experiment trace: the colorblind palette followed by the dark palette.
colors = [*sns.color_palette("colorblind"), *sns.color_palette("dark")]

def plot_experiment_class(experiment_class, npi_comb_dict, x_lims=None, default_res=None, default_names=None, width=1, newfig=True, prefix=""):
    """Plot every experiment of one class on shared axes and save the figure as PDF.

    Args:
        experiment_class: object exposing ``experiments`` (list of result
            dicts with ``'alpha_i'`` and ``'cm_names'``) and ``exp_info``
            (needs ``"default_label"``, ``"labeler"``, ``"title"``, ``"tag"``;
            optional ``"add"`` listing extra seasonality rows).
        npi_comb_dict: group definitions; keys are the row labels.
        x_lims: ``(x_min, x_max)``; ``None`` uses ``setup_plot``'s own default
            (previously ``None`` was forwarded and crashed there).
        default_res: optional samples for an additional black trace.
        default_names: countermeasure names belonging to ``default_res``.
        width: line width of the interval bars.
        newfig: forwarded to ``setup_plot``.
        prefix: inserted into the output file name before the tag.

    Side effects:
        Draws on the current matplotlib figure and writes
        ``figures/appendix/sensitivity/Fig_<prefix><tag>.pdf``, creating the
        directory if it does not exist yet.
    """
    exp_info = experiment_class.exp_info
    default_label = exp_info["default_label"]
    labeler = exp_info["labeler"]
    title = exp_info["title"]
    if callable(title):
        # A callable title is resolved against the first experiment of the class.
        title = title(experiment_class.experiments[0])
    add = exp_info.get("add", ())

    # Per-experiment extra rows derived from the seasonality amplitude samples.
    # The strings compared against `an` are the identifiers used in the
    # exp_info "add" lists and must stay in sync with them.
    extra_effects = []
    for e in experiment_class.experiments:
        extras = {}
        for an in add:
            if an == "Seasonality peak-to-through":
                b1 = np.array(e["seasonality_beta1"])
                # Displayed label spells "trough" correctly; the config key above is unchanged.
                extras["Seasonality peak-to-trough\nreduction"] = 100 * (1-(1-b1)/(1+b1))
            if an == "Seasonality amplitude":
                extras["Seasonality amplitude γ\n(in percent)"] = 100 * np.array(e["seasonality_beta1"])
        extra_effects.append(extras)

    # Extend the row list with the extra rows (taken from the last experiment).
    npi_comb_dict_ext = dict(npi_comb_dict)
    for en in extra_effects[-1].keys():
        npi_comb_dict_ext[en] = None

    # Only override setup_plot's x-limits when the caller supplied some.
    setup_kwargs = {} if x_lims is None else {"x_lims": x_lims}
    setup_plot(experiment_class, npi_comb_dict_ext, newfig=newfig, **setup_kwargs)

    # Spread the traces vertically within a row; the last slot is reserved
    # for the optional default trace.
    y_off = -np.linspace(-0.3, 0.3, len(experiment_class.experiments)+1)
    for i, trace in enumerate(experiment_class.experiments):
        # Cycle through the palette instead of failing when there are more traces than colours.
        add_trace_to_plot(np.array(trace['alpha_i']), y_off[i], colors[i % len(colors)], labeler(trace), alpha=1, 
                            width=width, npi_comb_dict=npi_comb_dict, cm_names=trace['cm_names'], extra=extra_effects[i])

    if default_res is not None:
        add_trace_to_plot(default_res, y_off[-1], "k", default_label, alpha=1, 
                            width=width, npi_comb_dict=npi_comb_dict, cm_names=default_names)

    plt.legend(shadow=True, fancybox=True, loc="upper right", bbox_to_anchor=(0.99, 0.99), fontsize=6)
    plt.title(title, fontsize=8)

    out_path = f'figures/appendix/sensitivity/Fig_{prefix}{exp_info["tag"]}.pdf'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, bbox_inches='tight')

In [None]:
# Sensitivity analyses plotted separately for each model.
# Keys are experiment tags; each entry provides:
#   "title"         - callable building the figure title from an experiment dict
#   "labeler"       - callable building the legend label of one experiment
#   "default_label" - legend label of the optional baseline trace
#   "add"           - identifiers of extra seasonality rows appended to the plot
per_model_exp_info = {
    "seasonality_max_R_day_fixed": {
        "title": lambda d: f"{d['MODEL']} model, {d['DATA']} data sensitivity to peak seasonality day",
        # Converts the 1-based day-of-year of the peak into a date in 2020, e.g. "Peak: 01 Jan".
        "labeler": lambda d: (datetime.datetime(2020, 1, 1) + datetime.timedelta(days=int(d["model_kwargs"]["max_R_day_prior"]["value"])-1)).strftime("Peak: %d %b"),
        "default_label": "No seasonality",
        "add": ["Seasonality amplitude", "Seasonality peak-to-through"],
    },
    "seasonality_basic_R_prior": {
        "title": lambda d: f"{d['MODEL']} model, {d['DATA']} data sensitivity to initial R0",
        "labeler": lambda d: f'Initial R0 = {d["model_kwargs"]["basic_R_prior"]["mean"]:.2f}',
        "default_label": "No seasonality",
        "add": ["Seasonality amplitude", "Seasonality peak-to-through"],
    },
}


def NPI_effects_labeler(d):
    """Legend label for the cross-model comparison plot.

    Format: ``"<MODEL> <seasonal|original><suffix>"`` where the variant is
    "seasonal" when the experiment dict has a ``"seasonality_beta1"`` entry,
    and the suffix is " (all)" for DATA == "Brauner", " (TE)" for
    DATA == "BraunerTE", and empty otherwise.
    """
    dataset = d['DATA']
    if dataset == "BraunerTE":
        suffix = " (TE)"
    elif dataset == "Brauner":
        suffix = " (all)"
    else:
        suffix = ""

    variant = "seasonal" if "seasonality_beta1" in d else "original"
    return f'{d["MODEL"]} {variant}{suffix}'

# Comparison of NPI effects across models on the selected dataset.
# The experiment class collects every run tagged "NPI_effects" plus, through
# "exp_matcher", the "default" runs and the seasonal runs whose basic-R prior
# mean is 3.3 or 1.35.
cross_model_exp_info = {
    "NPI_effects": {
        # Remove a trailing "TE" marker explicitly: str.rstrip('TE') strips any run of
        # the characters 'T'/'E' and would mangle other names ending in those letters.
        "title": f"{DATA[:-len('TE')] if DATA.endswith('TE') else DATA} et al. dataset NPI effects",
        "labeler": NPI_effects_labeler,
        "exp_matcher": lambda d: d['exp_tag'] == "default" or (d['exp_tag'] == "seasonality_basic_R_prior" and d["model_kwargs"]["basic_R_prior"]["mean"] in (3.3, 1.35)),
        "default_label": "undefined",
    },    
}


In [None]:
def _strip_te_suffix(name):
    """Return ``name`` without a trailing 'TE' marker (e.g. 'BraunerTE' -> 'Brauner').

    Unlike ``name.rstrip('TE')``, which strips any trailing run of the
    characters 'T' and 'E', this removes only that exact suffix.
    """
    return name[:-len('TE')] if name.endswith('TE') else name


# Load every saved run, then keep those fitted on the selected dataset family
# ("Brauner" and "BraunerTE" are treated as the same family).
all_experiments = get_all_experiments('../../sensitivity_final/')
all_experiments = [e for e in all_experiments if _strip_te_suffix(e['DATA']) == _strip_te_suffix(DATA)]
print(["seasonality_beta1" in e for e in all_experiments])

print(f"Filtered down to: {len(all_experiments)} for data {DATA}")

# Per-model sensitivity figures: one figure per (model, experiment tag).
for MODEL in ["Brauner", "Sharma"]:
    experiments = [e for e in all_experiments if f"model{MODEL}" in e['model_config_name']]
    print(f"Filtered down to: {len(experiments)} for model {MODEL}")
    experiment_classes = make_all_experiment_classes(experiments, per_model_exp_info)
    print("Classes:", {c.exp_info["tag"]: len(c.experiments) for c in experiment_classes})
    for experiment_class in experiment_classes:
        # The Sharma layout has many more rows, hence the taller figure.
        plt.figure(figsize=(4, 8 if DATA=="Sharma" else 4), dpi=400)
        plot_experiment_class(experiment_class, grouped_npis, (-25, 100), default_alpha, default_names, prefix=f"model{MODEL}_data{DATA}_", newfig=False)

# Cross-model comparison figures on the same dataset.
experiment_classes = make_all_experiment_classes(all_experiments, cross_model_exp_info)
print("Classes:", {c.exp_info["tag"]: [e['FILENAME'] for e in c.experiments] for c in experiment_classes})
for experiment_class in experiment_classes:
    plt.figure(figsize=(4, 8 if DATA=="Sharma" else 4), dpi=400)
    plot_experiment_class(experiment_class, grouped_npis, (-25, 100), default_alpha, default_names, prefix=f"data{DATA}_", newfig=False)
