In [1]:
import pickle
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

import covid19sim

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
%matplotlib inline

import sys
sys.path.append("../plots")
from plot_rt import PlotRt
from utils import plot_intervention

In [2]:
# Earlier transformer runs (seeds 1000-1002, "low-tests" variant).
# NOTE(review): `filenames_transformer` is reassigned in a later cell before any
# plotting call, so these paths are never plotted; kept only for reference.
filenames_transformer = [
    "../src/covid19sim/transformer/sim_v2_people-1000_days-45_init-0.001_seed-1000_20200515-153325/tracker_data_n_1000_seed_1000_20200515-185404_transformer-low-tests.pkl",
    "../src/covid19sim/transformer/sim_v2_people-1000_days-45_init-0.001_seed-1001_20200515-153325/tracker_data_n_1000_seed_1001_20200515-185339_transformer-low-tests.pkl",
    "../src/covid19sim/transformer/sim_v2_people-1000_days-45_init-0.001_seed-1002_20200515-153325/tracker_data_n_1000_seed_1002_20200515-190134_transformer-low-tests.pkl",
]

# Older experiment runs (seeds 1000-1004), disabled. Note that `filenames_lockdown`
# exists only here, so anything that references it needs this block re-enabled.
# filenames_digital = [
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1000_20200511-220418_digital.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1001_20200511-220429_digital.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1002_20200511-220238_digital.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1003_20200511-220305_digital.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1004_20200511-220353_digital.pkl"
# ]

# filenames_lockdown = [
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1000_20200511-221006_lockdown.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1001_20200511-220928_lockdown.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1002_20200511-220932_lockdown.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1003_20200511-221013_lockdown.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1004_20200511-221000_lockdown.pkl"
# ]

# filenames_unmitigated = [
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1000_20200511-215150_unmitigated.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1001_20200511-215155_unmitigated.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1002_20200511-215206_unmitigated.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1003_20200511-215209_unmitigated.pkl",
#     "../src/covid19sim/tune/exp_google_1000/tracker_data_n_1000_seed_1004_20200511-215220_unmitigated.pkl"
# ]

In [3]:
def plot_line_with_bounds(df, ax, color, label, mobility=False, **kwargs):
    """Plot the across-run mean of ``df`` with a shaded +/- band around it.

    Each column of ``df`` is one run and each row one time step. The line is the
    row-wise mean; the band spans ``mean +/- band_scale * std`` where ``std`` is
    the pandas row-wise standard deviation (``ddof=1``).

    Args:
        df: DataFrame with one column per run and one row per time step.
        ax: Axes to draw on; only ``plot`` and ``fill_between`` are called.
        color: colour of the line and band.
        label: legend label for the mean line.
        mobility: unused; kept so existing callers passing it keep working.
        **kwargs: optional styling:
            ``linestyle`` (default "-"), ``alpha`` (default 1.0; the line is
            drawn at ``0.8 * alpha``), ``marker`` (default None), ``markersize``
            (default 1), ``linewidth`` (default 2) and ``band_scale``
            (default 0.80, the std multiplier for the band).

    Returns:
        The same ``ax``, for chaining.
    """
    linestyle = kwargs.get("linestyle", "-")
    alpha = kwargs.get("alpha", 1.0)
    marker = kwargs.get("marker", None)
    markersize = kwargs.get("markersize", 1)
    linewidth = kwargs.get("linewidth", 2)
    band_scale = kwargs.get("band_scale", 0.80)

    index = np.arange(df.shape[0])
    mean = df.mean(axis=1)
    spread = band_scale * df.std(axis=1)
    # The band is only ever evaluated at the integer positions it is defined on,
    # so it is used directly; the previous interp1d round-trip added nothing and
    # raised ValueError when df had fewer than two rows.
    lows = (mean - spread).to_numpy(dtype=float)
    highs = (mean + spread).to_numpy(dtype=float)

    ax.plot(mean, color=color, alpha=0.8 * alpha, linestyle=linestyle,
            linewidth=linewidth, label=label, marker=marker, ms=markersize)
    ax.fill_between(index, lows, highs, color=color, alpha=.05, lw=0, zorder=3)

    return ax
    

In [79]:
def plot_all(filenames, labels, title, end_day = 20):
    """Build the two-panel method-comparison figure with English labels.

    Top panel: cumulative % of the population infected, one curve per group.
    Bottom panel: estimated $R_t$, one curve per group. Each curve is the
    across-seed mean with a shaded band.

    Args:
        filenames: list of groups; each group is a list of tracker pickle paths
            (one per seed) for one method.
        labels: legend label per group, in the same order as ``filenames``.
        title: figure suptitle.
        end_day: number of days plotted; x ticks are placed every 2 days below it.

    Returns:
        The matplotlib ``Figure`` (left open; the caller saves/displays it).
    """
    colormap = ["red", "orange", "blue", "green"]

    # Read the intervention day from the first run of the last group (file is closed).
    # NOTE(review): assumes every run in every group shares this value.
    with open(filenames[-1][0], "rb") as f:
        intervention_day = pickle.load(f)['intervention_day']

    fig, (ax, rax) = plt.subplots(nrows=2, ncols=1, figsize=(15, 14), sharex=True, dpi=500)
    dual_ax = None  # secondary mobility axis; disabled (mobility=False below)

    for c, filelist in enumerate(filenames):
        # Cycle colours so more than four groups do not raise IndexError.
        color = colormap[c % len(colormap)]
        ax, dual_ax = plot_group_cases(filelist, ax, dual_ax, mobility=False, color=color,
                                       end_day=end_day, label=labels[c])

    for c, filelist in enumerate(filenames):
        color = colormap[c % len(colormap)]
        rax = plot_group_R(filelist, rax, color=color, end_day=end_day, label=labels[c])

    # Intervention marker, plus the R_t = 1 reference line on the lower panel.
    if intervention_day > 0:
        ax.axvline(x=intervention_day-1, linestyle="-.", linewidth=3, alpha=0.6)
        ax.annotate("Intervention", xy=(intervention_day-1.8, 4), xytext=(intervention_day-1.8, 2.), size=20, rotation="vertical")

        rax.axvline(x=intervention_day-1, linestyle="-.", linewidth=3, alpha=0.6)
        rax.axhline(y=1.0, linestyle="-.", linewidth=3, color="green", alpha=0.3)
        rax.annotate("$R_t$ = 1.0", xy=(intervention_day, 1.0), xytext=(intervention_day -5 , 1.10), size=20, rotation="horizontal")

    for axis in (ax, rax):
        axis.grid(True, axis='x', alpha=0.3)
        axis.grid(True, axis='y', alpha=0.3)
        axis.tick_params(labelsize=25)

    rax.set_ylim(0, 4)

    # Keep only one legend entry per group on the cases panel.
    handles, legend_labels = ax.get_legend_handles_labels()
    ax.legend(handles[:len(labels)], legend_labels[:len(labels)], prop={"size":20}, loc="upper left")

    handles, legend_labels = rax.get_legend_handles_labels()
    rax.legend(handles, legend_labels, prop={"size":20}, loc="upper right")

    ax.set_ylabel("% Population Infected", fontsize=25, labelpad=20)
    rax.set_ylabel("$R_t$", fontsize=25, labelpad=20)
    rax.set_xlabel("Days since outbreak", fontsize=20)

    # Ticks follow end_day (previously hard-coded to 20; identical for 19 and 20).
    tick_stop = 20 if end_day is None else end_day
    rax.set_xticks(np.arange(0, tick_stop, 2))
    fig.suptitle(title, fontsize=30)

    return fig


In [80]:
def plot_all_in_french(filenames, labels, title, end_day = 20):
    """Build the two-panel method-comparison figure with French axis labels.

    Same layout as the English figure: cumulative % infected on top and
    estimated $R_t$ below, one across-seed mean curve (with band) per group.
    Only the axis labels and the placement of the "Intervention" annotation
    differ.

    Args:
        filenames: list of groups; each group is a list of tracker pickle paths
            (one per seed) for one method.
        labels: legend label per group, in the same order as ``filenames``.
        title: figure suptitle.
        end_day: number of days plotted; x ticks are placed every 2 days below it.

    Returns:
        The matplotlib ``Figure`` (left open; the caller saves/displays it).
    """
    colormap = ["red", "orange", "blue", "green"]

    # Read the intervention day from the first run of the last group (file is closed).
    # NOTE(review): assumes every run in every group shares this value.
    with open(filenames[-1][0], "rb") as f:
        intervention_day = pickle.load(f)['intervention_day']

    fig, (ax, rax) = plt.subplots(nrows=2, ncols=1, figsize=(15, 14), sharex=True, dpi=500)
    dual_ax = None  # secondary mobility axis; disabled (mobility=False below)

    for c, filelist in enumerate(filenames):
        # Cycle colours so more than four groups do not raise IndexError.
        color = colormap[c % len(colormap)]
        ax, dual_ax = plot_group_cases(filelist, ax, dual_ax, mobility=False, color=color,
                                       end_day=end_day, label=labels[c])

    for c, filelist in enumerate(filenames):
        color = colormap[c % len(colormap)]
        rax = plot_group_R(filelist, rax, color=color, end_day=end_day, label=labels[c])

    # Intervention marker, plus the R_t = 1 reference line on the lower panel.
    if intervention_day > 0:
        ax.axvline(x=intervention_day-1, linestyle="-.", linewidth=3, alpha=0.6)
        xy = (intervention_day-1.8, 4)
        ax.annotate("Intervention", xy=xy, xytext=xy, size=20, rotation="vertical")

        rax.axvline(x=intervention_day-1, linestyle="-.", linewidth=3, alpha=0.6)
        rax.axhline(y=1.0, linestyle="-.", linewidth=3, color="green", alpha=0.3)
        rax.annotate("$R_t$ = 1.0", xy=(intervention_day, 1.0), xytext=(intervention_day -5 , 1.10), size=20, rotation="horizontal")

    for axis in (ax, rax):
        axis.grid(True, axis='x', alpha=0.3)
        axis.grid(True, axis='y', alpha=0.3)
        axis.tick_params(labelsize=25)

    rax.set_ylim(0, 4)

    # Keep only one legend entry per group on the cases panel.
    handles, legend_labels = ax.get_legend_handles_labels()
    ax.legend(handles[:len(labels)], legend_labels[:len(labels)], prop={"size":20}, loc="upper left")

    handles, legend_labels = rax.get_legend_handles_labels()
    rax.legend(handles, legend_labels, prop={"size":20}, loc="upper right")

    # French labels: "% of cases in the population" / "Number of days".
    ax.set_ylabel("% de cas dans la population", fontsize=25, labelpad=20)
    rax.set_ylabel("$R_t$", fontsize=25, labelpad=20)
    rax.set_xlabel("Nombre de jours", fontsize=20)

    # Ticks follow end_day (previously hard-coded to 20; identical for 19 and 20).
    tick_stop = 20 if end_day is None else end_day
    rax.set_xticks(np.arange(0, tick_stop, 2))
    fig.suptitle(title, fontsize=30)

    return fig




In [81]:
# assumption - all lists have same size
def plot_group_cases(filenames, ax, dual_ax, mobility=True, color="red", end_day=None, label=""):
    """Plot cumulative % infected (mean +/- band across runs) for one method.

    Args:
        filenames: tracker pickle paths for one method, one per seed. Each
            pickle must provide 'cases_per_day', 'n_init_infected' and
            'n_humans'; 'outside_daily_contacts' is read only when
            ``mobility`` is true.
        ax: Axes for the cases curve.
        dual_ax: secondary Axes for the mobility curve; used (and must not be
            None) only when ``mobility`` is true.
        mobility: also draw daily outside contacts on ``dual_ax``.
        color: colour of the curves and bands.
        end_day: keep only the first ``end_day`` days (None keeps all).
        label: legend label; when empty it is derived from the first filename.

    Returns:
        ``(ax, dual_ax)``.
    """
    if not label:
        # e.g. ".../tracker_data_..._digital.pkl" -> "digital"
        label = filenames[0][:-3].split("_")[-1].split(".")[0]

    cases, mobile = [], []
    for filename in filenames:
        with open(filename, "rb") as f:
            data = pickle.load(f)

        # Cumulative cases including the initially infected, as % of population.
        # Not done in place: `+=` on an integer cumsum fails for a float count.
        cumulative = np.cumsum(data['cases_per_day'])[:end_day] + data['n_init_infected']
        cases.append(cumulative / data['n_humans'] * 100)

        # Only required (and only read) when the mobility overlay is drawn.
        if mobility:
            mobile.append(data['outside_daily_contacts'][:end_day])

    # One column per run, one row per day.
    cases = pd.DataFrame(cases).transpose()
    ax = plot_line_with_bounds(cases, ax, color, label)

    if mobility:
        mobile = pd.DataFrame(mobile).transpose()
        dual_ax = plot_line_with_bounds(mobile, dual_ax, color, label, mobility=True, linestyle="--", alpha=0.3)

    return ax, dual_ax

def plot_group_R(filenames, ax, color="red", end_day=None, label=""):
    """Plot the estimated R_t (mean +/- band across runs) for one method.

    For each run, R_t is estimated from the daily case counts with ``PlotRt``
    using ``GAMMA = 1 / serial_interval``. If the recorded serial interval is
    not positive, it falls back to 7.0 and prints a warning.

    Args:
        filenames: tracker pickle paths for one method, one per seed; each must
            provide 'cases_per_day' and 'serial_interval'.
        ax: Axes to draw on.
        color: colour of the curve and band.
        end_day: number of days to plot; None keeps everything ``PlotRt`` returns.
        label: legend label; when empty it is derived from the first filename.

    Returns:
        ``ax``.
    """
    if not label:
        label = filenames[0][:-3].split("_")[-1].split(".")[0]

    # The estimator receives 4 days beyond end_day.
    # NOTE(review): presumably so the last plotted day has trailing data — confirm.
    # Guarded because the default end_day=None used to raise TypeError here.
    cases_end = None if end_day is None else end_day + 4

    Rt = []
    for filename in filenames:
        with open(filename, "rb") as f:
            data = pickle.load(f)

        cases_per_day = data['cases_per_day'][:cases_end]
        if data['serial_interval'] > 0:
            serial_interval = data['serial_interval']
        else:
            serial_interval = 7.0
            print("WARNING: serial_interval is 0")

        print(f"using serial interval :{serial_interval}")
        plotrt = PlotRt(R_T_MAX=4, sigma=0.25, GAMMA=1.0/serial_interval)
        most_likely, _ = plotrt.compute(cases_per_day, r0_estimate=2.5)
        Rt.append(most_likely[:end_day].tolist())

    # One column per run, one row per day.
    Rt = pd.DataFrame(Rt).transpose()
    ax = plot_line_with_bounds(Rt, ax, color, label, linestyle=":", marker="P", markersize=5)
    return ax

In [35]:
# Tracker pickles per method, one file per simulation seed (1005-1009).
# Binary (non-ML) digital contact tracing runs.
filenames_digital = [

     "../src/covid19sim/binary_digital_tracing/tracker_data_n_1000_seed_1005_20200516-195906_.pkl",
     "../src/covid19sim/binary_digital_tracing/tracker_data_n_1000_seed_1006_20200516-195907_.pkl",
     "../src/covid19sim/binary_digital_tracing/tracker_data_n_1000_seed_1007_20200516-195910_.pkl",
     "../src/covid19sim/binary_digital_tracing/tracker_data_n_1000_seed_1008_20200516-195907_.pkl",
     "../src/covid19sim/binary_digital_tracing/tracker_data_n_1000_seed_1009_20200516-195903_.pkl",
]

# Social-distancing runs.
filenames_social_distancing = [
    # Same Effective Contacts
    "../src/covid19sim/social_distancing/same_effective_contacts/tracker_data_n_1000_seed_1005_20200517-073603_.pkl",
    "../src/covid19sim/social_distancing/same_effective_contacts/tracker_data_n_1000_seed_1006_20200517-073605_.pkl",
    "../src/covid19sim/social_distancing/same_effective_contacts/tracker_data_n_1000_seed_1007_20200517-073604_.pkl",
    "../src/covid19sim/social_distancing/same_effective_contacts/tracker_data_n_1000_seed_1008_20200517-073604_.pkl",
    "../src/covid19sim/social_distancing/same_effective_contacts/tracker_data_n_1000_seed_1009_20200517-073602_.pkl",
]

# Unmitigated baseline runs.
filenames_unmitigated = [
     "../src/covid19sim/unmitigated/tracker_data_n_1000_seed_1005_20200516-213201_.pkl",
     "../src/covid19sim/unmitigated/tracker_data_n_1000_seed_1006_20200516-213203_.pkl",
     "../src/covid19sim/unmitigated/tracker_data_n_1000_seed_1007_20200516-213201_.pkl",
     "../src/covid19sim/unmitigated/tracker_data_n_1000_seed_1008_20200516-213202_.pkl",
     "../src/covid19sim/unmitigated/tracker_data_n_1000_seed_1009_20200516-213201_.pkl",

]

# Transformer (AI-based, COVI) tracing runs; this replaces any earlier
# `filenames_transformer` assignment in the notebook.
filenames_transformer = [
     "../src/covid19sim/transformer/tracker_data_n_1000_seed_1005_20200516-224040_.pkl",
     "../src/covid19sim/transformer/tracker_data_n_1000_seed_1006_20200516-224056_.pkl",
     "../src/covid19sim/transformer/tracker_data_n_1000_seed_1007_20200516-230831_.pkl",
     "../src/covid19sim/transformer/tracker_data_n_1000_seed_1008_20200516-230819_.pkl",
     "../src/covid19sim/transformer/tracker_data_n_1000_seed_1009_20200516-230800_.pkl",
]

# Older comparison; `filenames_lockdown` is only defined in a commented-out block.
# filenames = [filenames_unmitigated, filenames_digital, filenames_transformer,  filenames_lockdown]
# labels = ["unmitigated", "digital", "transformer", "lockdown"]
# Group order sets the colours in plot_all / plot_all_in_french (red, orange, blue, green)
# and must match the order of labels_fr / labels_en below.
filenames = [filenames_unmitigated, filenames_social_distancing, filenames_digital, filenames_transformer]
# French labels: "blind reopening", "social distancing", "binary tracing", "AI tracing".
labels_fr = ["Déconfinement aveugle", "Distanciation sociale", "Dépistage binaire", "Dépistage IA"] 
# French title: "Comparison of tracing methods (60% adoption rate)".
title_fr = "Comparaison de méthodes de traçage (taux d'adoption de 60%)"
title_en ="Comparison of Tracing Methods (60% Adoption Rate)"
labels_en = ["Unmitigated", "Social distancing", "Binary contact tracing app", "AI-based app (COVI)"] 
# end_day=19 plots days 0-18.
fig = plot_all_in_french(filenames=filenames, labels=labels_fr, title=title_fr , end_day=19)

# fig
fig.savefig("covi_method_comparison_social_distancing_fr.pdf") 


  fig, (ax, rax) = plt.subplots(nrows=2, ncols=1, figsize=(15,14), sharex=True, dpi=500)


using serial interval :5.327083333333333
using serial interval :6.895833333333333
using serial interval :7.0
using serial interval :5.813715277777778
using serial interval :5.928819444444445
using serial interval :7.0
using serial interval :5.793749999999999
using serial interval :5.83611111111111
using serial interval :6.27037037037037
using serial interval :7.598611111111111
using serial interval :7.0
using serial interval :7.7027777777777775
using serial interval :5.618055555555555
using serial interval :7.125
using serial interval :5.9215277777777775
using serial interval :5.954861111111111
using serial interval :7.0
using serial interval :6.2444444444444445
using serial interval :5.38125
using serial interval :6.340277777777778


In [83]:
# English version of the same comparison figure (days 0-18), exported to PDF.
fig = plot_all(
    filenames=filenames,
    labels=labels_en,
    title=title_en,
    end_day=19,
)
fig.savefig("covi_method_comparison_social_distancing_en.pdf")


  fig, (ax, rax) = plt.subplots(nrows=2, ncols=1, figsize=(15,14), sharex=True, dpi=500)


using serial interval :5.327083333333333
using serial interval :6.895833333333333
using serial interval :7.0
using serial interval :5.813715277777778
using serial interval :5.928819444444445
using serial interval :7.0
using serial interval :5.793749999999999
using serial interval :5.83611111111111
using serial interval :6.27037037037037
using serial interval :7.598611111111111
using serial interval :7.0
using serial interval :7.7027777777777775
using serial interval :5.618055555555555
using serial interval :7.125
using serial interval :5.9215277777777775
using serial interval :5.954861111111111
using serial interval :7.0
using serial interval :6.2444444444444445
using serial interval :5.38125
using serial interval :6.340277777777778
