In [1]:
# libraries
import sys
sys.path.append("../models/")
from seird_model import integrate_SEIRD
from functions import load_epi_params, import_country, get_totR
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from pyabc import (ABCSMC,
                   RV, Distribution,
                   MedianEpsilon,
                   LocalTransition)
import pyabc
import matplotlib.pyplot as plt
import os
import tempfile
import scipy as sp
import pickle as pkl
from scipy.stats import pearsonr

# calibration window [start_date, end_date): the end date itself is not fitted
start_date, end_date = datetime(2020, 9, 1), datetime(2020, 12, 1)

# epidemiological parameters shared by every country fit below
IFR, mu, eps, Delta = load_epi_params("../../data/")

# ABC-SMC Calibration

In [8]:
def abc_calibration(IFR, mu, eps, Delta, country_dict, start_date, end_date):
    
    """
    This function run the ABC-SMC algorithm to fit the model's parameters 
    for a given country.
        :param IFR (array): age-stratified infection fatality rate
        :param mu (float): recovery rate
        :param eps (float): incubation rate
        :param country_dict (dict): dict with country data
        :param start_date (datetime): starting date
        :param end_date (datetime): end date
        :return: returns ABC-SMC fit object
    """

    
    ############# FUNCTIONS #############
    def metric(simulation, data):
        
        """
        This function computes the error metric as: 
        
                s(E, E') = (1 - rho(Inc., Inc.')) + abs(Deaths_tot - Deaths_tot') / Deaths_tot'
                
            :param simulation (dict): simulated data
            :param data (dict): real data
            :return: returns error metric
        """
        
        err_D = np.abs(simulation["deaths"].sum() - data["deaths"].sum()) / data["deaths"].sum()
        err_C = 1 - pearsonr(simulation["cases"], data["cases"])[0]
        return err_D + err_C 


    def model(params):

        """
        Model Definition
            :param params (dict): parameters with a prior distribution
            :return: returns simulated data
        """

        # integrate 
        solution, dates = integrate_SEIRD(0.0, i0, r0,
                                          start_date, (end_date - start_date).days , 
                                          params["R0"], eps, mu, IFR, Delta,
                                          country_dict, 
                                          [params["w1"], params["w2"], params["w3"]],
                                          [params["o1"], params["o2"], params["o3"]], 
                                          [params["s1"], params["s2"], params["s3"]])
        
        # compute weekly deaths
        deaths = np.diff(solution.sum(axis=0)[5], prepend=0)  # daily deaths
        deaths_w, i, s = [], 0, 7
        while i + s < deaths.shape[0]:
            deaths_w.append(deaths[i:i+s].sum())
            i += s

        # compute weekly incidence 
        incidence = get_incidence(solution)
        incidence_w, i, s = [], 0, 7
        while i + s < incidence.shape[0]:
            incidence_w.append(incidence[i:i+s].sum())
            i += s
            
        return {"deaths": np.array(deaths_w), "cases": np.array(incidence_w), "solution": solution}
    
    
    def get_incidence(solution):
    
        """
        This function compure the daily incidence
            :param solution (matrix): SEIRD solution
            :return: returns daily incidence
        """

        # sum over age
        solution_age = solution.sum(axis=0)

        # compute incidence
        incidence = [0]
        for i in np.arange(1, solution_age.shape[1]):
            DeltaS  = solution_age[0][i] - solution_age[0][i-1]
            DeltaE  = solution_age[1][i] - solution_age[1][i-1]
            incidence.append(-(DeltaS + DeltaE))

        return np.array(incidence)
    #####################################
    
    
    ############### DATA ################
    # observed incidence (compute weekly)
    cases_real = country_dict["epi_data"].loc[(country_dict["epi_data"].index>=start_date) & (country_dict["epi_data"].index<end_date)]["cases"].values
    cases_real_w, i, s = [], 0, 7
    while i + s < cases_real.shape[0]:
        cases_real_w.append(cases_real[i:i+s].sum())
        i += s
    cases_real_w = np.array(cases_real_w)
    
    # observed deaths (compute weekly)
    deaths_real = country_dict["epi_data"].loc[(country_dict["epi_data"].index>=start_date) & (country_dict["epi_data"].index<end_date)]["deaths"].values
    deaths_real_w, i, s = [], 0, 7
    while i + s < deaths_real.shape[0]:
        deaths_real_w.append(deaths_real[i:i+s].sum())
        i += s
    deaths_real_w = np.array(deaths_real_w)
    ####################################
    
    
    ############## PRIORS ##############
    db_path = ("sqlite:///" + os.path.join(tempfile.gettempdir(), "test.db"))
    
    # new positive of previous week define I(t=0) / N
    new_pos = country_dict["epi_data"].loc[(country_dict["epi_data"].index>=start_date - timedelta(days=7)) & (country_dict["epi_data"].index<start_date)]["cases"].sum()    
    i0 = new_pos / country_dict["Nk"].sum()
    
    # R(t=0) / N
    r0 = get_totR("../../data/", start_date, country_dict["country"]) / country_dict["Nk"].sum()
    
    # define the priors
    parameter_prior = Distribution(R0=RV("uniform", 1.1, 2.0 - 1.1),
                                   w1=RV('rv_discrete', values=([25, 50, 75, 100], [1/4] * 4)),
                                   w2=RV('rv_discrete', values=([25, 50, 75, 100], [1/4] * 4)),
                                   w3=RV('rv_discrete', values=([25, 50, 75], [1/3] * 3)),
                                   o1=RV('rv_discrete', values=([25, 50, 75, 100], [1/4] * 4)),
                                   o2=RV('rv_discrete', values=([25, 50, 75, 100], [1/4] * 4)),
                                   o3=RV('rv_discrete', values=([25, 50, 75], [1/3] * 3)),
                                   s1=RV('rv_discrete', values=([25, 50, 75, 100], [1/4] * 4)),
                                   s2=RV('rv_discrete', values=([25, 50, 75, 100], [1/4] * 4)),
                                   s3=RV('rv_discrete', values=([25, 50, 75, 100], [1/4] * 4)))
    
    # define transitions
    transition = pyabc.AggregatedTransition(mapping={
        'w1': pyabc.DiscreteJumpTransition(domain=[25, 50, 75, 100], p_stay=0.7),
        'w2': pyabc.DiscreteJumpTransition(domain=[25, 50, 75, 100], p_stay=0.7),
        'w3': pyabc.DiscreteJumpTransition(domain=[25, 50, 75], p_stay=0.7),
        'o1': pyabc.DiscreteJumpTransition(domain=[25, 50, 75, 100], p_stay=0.7),
        'o2': pyabc.DiscreteJumpTransition(domain=[25, 50, 75, 100], p_stay=0.7),
        'o3': pyabc.DiscreteJumpTransition(domain=[25, 50, 75], p_stay=0.7),
        's1': pyabc.DiscreteJumpTransition(domain=[25, 50, 75, 100], p_stay=0.7),
        's2': pyabc.DiscreteJumpTransition(domain=[25, 50, 75, 100], p_stay=0.7),
        's3': pyabc.DiscreteJumpTransition(domain=[25, 50, 75, 100], p_stay=0.7),
        'R0': LocalTransition(k_fraction=.3)})
    ##################################
    
    
    # abc calibration
    abc = ABCSMC(models=model,
                 parameter_priors=parameter_prior,
                 distance_function=metric,
                 population_size=100,
                 transitions=transition,
                 eps=MedianEpsilon(1.5, median_multiplier=0.8))
    abc.new(db_path, {"deaths": deaths_real_w, "cases": cases_real_w})
    
    # run
    h = abc.run(minimum_epsilon=0.05, 
                max_nr_populations=5, 
                max_total_nr_simulations=7000, 
                max_walltime=timedelta(hours=3))
    
    # save fit results
    for t in [0,1,2,3,4]:
        with open("./abc_fit/" + country_dict["country"] + "_distribution_t" + str(t) + ".pkl", "wb") as file: 
            pkl.dump(h.get_distribution(t=t), file)
            
        with open("./abc_fit/" + country_dict["country"] + "_sumstats_t" + str(t) + ".pkl", "wb") as file:
            pkl.dump(h.get_weighted_sum_stats(t=t), file)
            
    return h

Fit the model to each country's data:


- Italy

In [9]:
# load Italy's data and run the ABC-SMC calibration over the fit window
country = "Italy"
country_dict = import_country(country, "../../data/countries/")
h = abc_calibration(IFR, mu, eps, Delta,
                    country_dict=country_dict,
                    start_date=start_date,
                    end_date=end_date)

INFO:Sampler:Parallelizing the sampling on 4 cores.
INFO:History:Start <ABCSMC(id=9, start_time=2020-12-12 11:27:06.487713, end_time=None)>
INFO:ABC:t: 0, eps: 1.5.
INFO:ABC:Acceptance rate: 100 / 128 = 7.8125e-01, ESS=1.0000e+02.
INFO:ABC:t: 1, eps: 0.6835819986475865.
INFO:ABC:Acceptance rate: 100 / 373 = 2.6810e-01, ESS=8.6012e+01.
INFO:ABC:t: 2, eps: 0.31389745176414907.
INFO:ABC:Acceptance rate: 100 / 837 = 1.1947e-01, ESS=7.6130e+01.
INFO:ABC:t: 3, eps: 0.15910603024761116.
INFO:ABC:Acceptance rate: 100 / 2185 = 4.5767e-02, ESS=7.3660e+01.
INFO:ABC:t: 4, eps: 0.07143752371471922.
INFO:ABC:Acceptance rate: 100 / 6129 = 1.6316e-02, ESS=7.6138e+01.
INFO:ABC:Stopping: total simulations budget.
INFO:History:Done <ABCSMC(id=9, start_time=2020-12-12 11:27:06.487713, end_time=2020-12-12 12:41:58.513406)>


- France

In [6]:
# load France's data and run the ABC-SMC calibration over the fit window
country = "France"
country_dict = import_country(country, "../../data/countries/")
h = abc_calibration(IFR, mu, eps, Delta,
                    country_dict=country_dict,
                    start_date=start_date,
                    end_date=end_date)

INFO:Sampler:Parallelizing the sampling on 4 cores.
INFO:History:Start <ABCSMC(id=7, start_time=2020-12-12 00:18:53.666740, end_time=None)>
INFO:ABC:t: 0, eps: 1.5.
INFO:ABC:Acceptance rate: 100 / 190 = 5.2632e-01, ESS=1.0000e+02.
INFO:ABC:t: 1, eps: 0.6373717600079074.
INFO:ABC:Acceptance rate: 100 / 560 = 1.7857e-01, ESS=7.8551e+01.
INFO:ABC:t: 2, eps: 0.3762223123795359.
INFO:ABC:Acceptance rate: 100 / 964 = 1.0373e-01, ESS=8.1794e+01.
INFO:ABC:t: 3, eps: 0.2195912302942111.
INFO:ABC:Acceptance rate: 100 / 1608 = 6.2189e-02, ESS=5.6778e+01.
INFO:ABC:t: 4, eps: 0.1270414439716301.
INFO:ABC:Acceptance rate: 100 / 4939 = 2.0247e-02, ESS=5.4158e+01.
INFO:ABC:Stopping: total simulations budget.
INFO:History:Done <ABCSMC(id=7, start_time=2020-12-12 00:18:53.666740, end_time=2020-12-12 01:15:18.430805)>


- United Kingdom

In [7]:
# load the United Kingdom's data and run the ABC-SMC calibration over the fit window
country = "United_Kingdom"
country_dict = import_country(country, "../../data/countries/")
h = abc_calibration(IFR, mu, eps, Delta,
                    country_dict=country_dict,
                    start_date=start_date,
                    end_date=end_date)

INFO:Sampler:Parallelizing the sampling on 4 cores.
INFO:History:Start <ABCSMC(id=8, start_time=2020-12-12 01:15:20.081604, end_time=None)>
INFO:ABC:t: 0, eps: 1.5.
INFO:ABC:Acceptance rate: 100 / 143 = 6.9930e-01, ESS=1.0000e+02.
INFO:ABC:t: 1, eps: 0.7212727524780076.
INFO:ABC:Acceptance rate: 100 / 539 = 1.8553e-01, ESS=6.2182e+01.
INFO:ABC:t: 2, eps: 0.4431880409312303.
INFO:ABC:Acceptance rate: 100 / 786 = 1.2723e-01, ESS=6.3503e+01.
INFO:ABC:t: 3, eps: 0.2684127883923822.
INFO:ABC:Acceptance rate: 100 / 2195 = 4.5558e-02, ESS=7.7120e+01.
INFO:ABC:t: 4, eps: 0.17327750492507699.
INFO:ABC:Acceptance rate: 100 / 4541 = 2.2022e-02, ESS=5.7760e+01.
INFO:ABC:Stopping: total simulations budget.
INFO:History:Done <ABCSMC(id=8, start_time=2020-12-12 01:15:20.081604, end_time=2020-12-12 02:12:09.808246)>
