In [36]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.signal import savgol_filter

In [37]:
from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter
from summer2.functions import get_piecewise_scalar_function

In [38]:
def pcwise_temp(x,domain,breakpts,vals):
    y = np.empty_like(x)
    num_intervals = len(breakpts)+1
    masks = []
    m = (domain[0] <= x) & (x <= breakpts[0])
    masks.append(m)
    for i in range(1,num_intervals-1):
        m = (breakpts[i-1] < x) & (x <= breakpts[i])
        masks.append(m)
    m = (breakpts[len(breakpts)-1] < x) & (x <= domain[1])
    masks.append(m)
    for j in range(0,num_intervals):
        y[masks[j]] = vals[j]
    return y

def pcwise_fcn(domain,breakpts,vals):
    def f(x):
        y = pcwise_temp(x,domain,breakpts,vals)
        return y 
    return f

In [39]:
#Plot piecewise function
domain = [0,10]
breakpts = [2,7]
vals = [1,2,3]
f = pcwise_fcn(domain,breakpts,vals)
xlist = np.linspace(domain[0], domain[1], num = 1000)
ylist = f(xlist)
fig = go.Figure()
fig = fig.add_trace(go.Scatter(x = xlist, y=ylist, name = "f"))
fig.show()

**Model 2: SEIQR (Stratified Infectious Compartment)**

In [40]:
def generate_transmission_parameter(num_breakpts):
    breakpts = []
    s = Parameter("len_pd"+str(1))
    breakpts.append(s)
    for i in range(2,num_breakpts+1):
        s += Parameter("len_pd"+str(i))
        breakpts.append(s)
    rates = []
    for k in range(1,num_breakpts+2):
        rates.append(Parameter("rate"+str(k)))
    return breakpts, rates

In [41]:
def generate_default_breakpts_and_rates(times,num_breakpts):
    time_length = times[1] - times[0]
    subinterval_length = time_length / (num_breakpts + 1)
    defp = {}
    for i in range(1,num_breakpts+2):
        defp["len_pd"+str(i)] = 0.9*subinterval_length
        defp["rate"+str(i)] = 0.2
    return defp

In [42]:
def generate_transmission_priors(times,num_breakpts):
    time_length = times[1] - times[0]
    subinterval_length = time_length / (num_breakpts + 1)
    priors = [esp.UniformPrior("rate1",(0,1.2))]
    for i in range(1,num_breakpts+1):
        priors.append(esp.UniformPrior("len_pd"+str(i),(0.25*subinterval_length,1.5*subinterval_length)))
        #prior for len_pd might need adjustments
        priors.append(esp.UniformPrior("rate"+str(i+1),(0,1.2)))
    return priors
    

In [43]:
def build_model(num_breakpts):
    #num_breakpts = number of breakpts for pcwise constant effective transmission rate parameter
    m = CompartmentalModel(
    times=[0, 268],
    compartments=["S", "E1", "E2", "E3", "E4", "I1", "I2", "I3", "I4", "Q1", "Q2", "Q3", "Q4", "Q2*", "Q3*", "Q4*", "R"],
    infectious_compartments=["I1", "I2", "I3", "I4"],
    timestep = 1,
    ref_date=datetime(2020,4,7)
    )
    m.set_initial_population({"S": 10353498.0, "E1": 360.0 , "I1": 544.0})
    #We set up a piecewise constant "effective contact rate"
    breakpts, rates = generate_transmission_parameter(num_breakpts)
    m.add_infection_frequency_flow("infection", get_piecewise_scalar_function(breakpts, rates),"S","E1") 
    #Progression rate for wild-type strain is 1/6.65 (Yu Wu et al.), which is multiplied by 4 because of the chained E compartmens
    m.add_transition_flow("progression1", 0.6,"E1","E2")
    m.add_transition_flow("progression2", 0.6,"E2","E3")
    m.add_transition_flow("progression3", 0.6,"E3","E4")
    m.add_transition_flow("progression4", 0.6,"E4","I1")
    #Detection can happen from any of the I compartments
    m.add_transition_flow("notification1", Parameter("detection_rate"),"I1","Q1")
    m.add_transition_flow("notification2", Parameter("detection_rate"),"I2","Q2")
    m.add_transition_flow("notification3", Parameter("detection_rate"),"I3","Q3")
    m.add_transition_flow("notification4", Parameter("detection_rate"),"I4","Q4")
    # We set up the ff. progressions:
    #   I1 -> I2 -> I3 -> I4 -> R
    #   (I1 ->) Q1 -> Q2* -> Q3* -> Q4* -> R
    #         (I2 ->) Q2 -> Q3* -> Q4* -> R
    #               (I3 ->) Q_3 -> Q_4* -> R
    #                     (I4 ->) Q_4 -> R
    # The progressions above are governed (ultimately) by a single parameter tau, which corresponds to the mean duration of infectiousness
    # Except for Q1 -> Q2*, Q2 -> Q3*, Q3 -> Q4*, Q4 -> R, the rate for each transition above is 4/tau (the 4 is due to the chained compartments structure).
    # To account for the detection rate, the rates for the flows Q1 -> Q2*, Q2 -> Q3*, Q3 -> Q4*, Q4 -> R are adjusted to maintain the same mean duration of infectiousness
    #Action point: Below is the assumption for now but this should be updated after checking relevant lit.
    tau = 14.0 
    m.add_transition_flow("I1_to_I2", 4/tau,"I1","I2") 
    m.add_transition_flow("I2_to_I3", 4/tau,"I2","I3") 
    m.add_transition_flow("I3_to_I4", 4/tau,"I3","I4") 
    m.add_transition_flow("I4_to_R", 4/tau,"I4","R")
    m.add_transition_flow("Q2*_to_Q3*", 4/tau,"Q2*","Q3*") 
    m.add_transition_flow("Q3*_to_Q4*", 4/tau,"Q3*","Q4*") 
    m.add_transition_flow("Q4*_to_R", 4/tau,"Q4*","R") 
    m.add_transition_flow("Q1_to_Q2*", 1/(tau - 1/Parameter("detection_rate")),"Q1","Q2*") 
    m.add_transition_flow("Q2_to_Q3*", 1/(tau - 1/Parameter("detection_rate")),"Q2","Q3*") 
    m.add_transition_flow("Q3_to_Q4*", 1/(tau - 1/Parameter("detection_rate")),"Q3","Q4*") 
    m.add_transition_flow("Q4_to_R", 1/(tau - 1/Parameter("detection_rate")),"Q4","R") 
    
    #No death dynamics for now since they are quite minimal at this point and data might not be accurate
    #Action point: Revisit this later to see if it needs updating    
    #m.add_death_flow("infection_death", Parameter("death_rate"), "I")
    #
    #Requesting Outputs
    m.request_output_for_flow("notification1", "notification1")
    m.request_output_for_flow("notification2", "notification2")
    m.request_output_for_flow("notification3", "notification3")
    m.request_output_for_flow("notification4", "notification4")
    m.request_aggregate_output(name = "notifications", sources=["notification1", "notification2", "notification3", "notification4"], save_results=True)
    #m.request_output_for_flow("infection_death", "infection_death")
    #m.request_cumulative_output(name="deaths_cumulative", source="infection_death")
    #
    #Action point: Incorporate age stratification  
    return m

In [44]:
num_breakpts = 38
m = build_model(num_breakpts)
m.get_input_parameters()


This method is deprecated and scheduled for removal, use get_piecewise_function instead



{'detection_rate',
 'len_pd1',
 'len_pd10',
 'len_pd11',
 'len_pd12',
 'len_pd13',
 'len_pd14',
 'len_pd2',
 'len_pd3',
 'len_pd4',
 'len_pd5',
 'len_pd6',
 'len_pd7',
 'len_pd8',
 'len_pd9',
 'rate1',
 'rate10',
 'rate11',
 'rate12',
 'rate13',
 'rate14',
 'rate15',
 'rate2',
 'rate3',
 'rate4',
 'rate5',
 'rate6',
 'rate7',
 'rate8',
 'rate9'}

In [45]:
defp_breakpts_rates = generate_default_breakpts_and_rates([m.times.min(),m.times.max()], num_breakpts)
parameters = {
    "rate1": 0.15, 
    "detection_rate": 0.1}
parameters.update(defp_breakpts_rates)

In [46]:
parameters

{'rate1': 0.2,
 'detection_rate': 0.1,
 'len_pd1': 16.080000000000002,
 'len_pd2': 16.080000000000002,
 'rate2': 0.2,
 'len_pd3': 16.080000000000002,
 'rate3': 0.2,
 'len_pd4': 16.080000000000002,
 'rate4': 0.2,
 'len_pd5': 16.080000000000002,
 'rate5': 0.2,
 'len_pd6': 16.080000000000002,
 'rate6': 0.2,
 'len_pd7': 16.080000000000002,
 'rate7': 0.2,
 'len_pd8': 16.080000000000002,
 'rate8': 0.2,
 'len_pd9': 16.080000000000002,
 'rate9': 0.2,
 'len_pd10': 16.080000000000002,
 'rate10': 0.2,
 'len_pd11': 16.080000000000002,
 'rate11': 0.2,
 'len_pd12': 16.080000000000002,
 'rate12': 0.2,
 'len_pd13': 16.080000000000002,
 'rate13': 0.2,
 'len_pd14': 16.080000000000002,
 'rate14': 0.2,
 'len_pd15': 16.080000000000002,
 'rate15': 0.2}

In [47]:
m.run(parameters)

In [48]:
fig = px.line(m.get_outputs_df())
fig.show()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



In [49]:
fig = px.line(m.get_derived_outputs_df())
fig.show()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



**Targets**

In [50]:
df=pd.read_excel(io='InfectionsData_060120_093020.xlsx',index_col=0)
notification_data = df["NOTIFICATIONS"]
#death_data = df["CUMULATIVE DEATHS"]

In [51]:
notification_smoothed = savgol_filter(notification_data, window_length=12, polyorder=2)
notification_smoothed = pd.DataFrame(notification_smoothed, index=notification_data.keys())
notification_smoothed = notification_smoothed.rename(columns={0:'smoothed_data'})
notification_smoothed = notification_smoothed['smoothed_data']

In [52]:
fig = px.line(notification_data)
fig = fig.add_trace(go.Scatter(x = notification_smoothed.keys(), y=notification_smoothed, name = "smoothed"))
fig.show()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



In [53]:
notification_target = notification_smoothed[31::2]
#death_data_cal = death_data[:90:2]

In [54]:
fig = px.scatter(notification_target)
fig.show()
#death_data_cal.plot(style='.')


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



**Optimization**

In [55]:
# Targets represent data we are trying to fit to
from estival import targets as est

# We specify parameters using (Bayesian) priors
from estival import priors as esp

# Finally we combine these with our summer2 model in a BayesianCompartmentalModel (BCM)
from estival.model import BayesianCompartmentalModel

In [56]:
# Import nevergrad
import nevergrad as ng

# Import our convenience wrapper
from estival.wrappers.nevergrad import optimize_model

In [57]:
targets = [
    #est.NormalTarget("notification", notification_data_cal, np.std(notification_data_cal) * 0.1)
    est.TruncatedNormalTarget("notifications", notification_target, (0.0,np.inf),
        esp.UniformPrior("notification_dispersion",(0.1, notification_target.max()*0.01))),
    #est.NormalTarget("deaths_cumulative", death_data_cal, np.std(death_data_cal) * 0.1)
]

In [58]:
transmission_priors = generate_transmission_priors([m.times.min(),m.times.max()],num_breakpts)

priors = [
    esp.UniformPrior("detection_rate", (0,0.5)),
] + transmission_priors

In [59]:
priors

[UniformPrior detection_rate {bounds: (0.0, 0.5)},
 UniformPrior rate1 {bounds: (0.0, 1.2)},
 UniformPrior len_pd1 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate2 {bounds: (0.0, 1.2)},
 UniformPrior len_pd2 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate3 {bounds: (0.0, 1.2)},
 UniformPrior len_pd3 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate4 {bounds: (0.0, 1.2)},
 UniformPrior len_pd4 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate5 {bounds: (0.0, 1.2)},
 UniformPrior len_pd5 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate6 {bounds: (0.0, 1.2)},
 UniformPrior len_pd6 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate7 {bounds: (0.0, 1.2)},
 UniformPrior len_pd7 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate8 {bounds: (0.0, 1.2)},
 UniformPrior len_pd8 {bounds: (4.466666666666667, 26.800000000000004)},
 UniformPrior rate9 {bounds: 

In [60]:
defp = parameters

In [61]:
bcm = BayesianCompartmentalModel(m, defp, priors, targets)

In [62]:
# TwoPointsDE is a good suggested default for some of our typical use cases
opt_class = ng.optimizers.TwoPointsDE
orunner = optimize_model(bcm, opt_class=opt_class)

In [63]:
# You can also suggest starting points for the optimization (as well as specify an init method for unsuggested points)
# This is the "midpoint" method by default (ie the 0.5 ppf of the prior distribution)
orunner = optimize_model(bcm, opt_class=opt_class, suggested=defp, init_method="midpoint")

In [64]:
# Here we run the optimizer in a loop, inspecting the current best point at each iteration
# Using the loss information at each step can provide the basis for stopping conditions

for i in range(8):
    # Run the minimizer for a specified 'budget' (ie number of evaluations)
    rec = orunner.minimize(800)
    # Print the loss (objective function value) of the current recommended parameters
    print(rec.loss)

179.22246434610227
90.58937824601223
79.72691187182696
77.40447089174707
73.81778780405892
72.236517157241
72.13407470648522
71.95039891379554


In [65]:
mle_params = rec.value[1]
mle_params

{'detection_rate': 0.1849805620977825,
 'rate1': 0.35507553978457307,
 'len_pd1': 18.255506856156128,
 'rate2': 0.2528228780303829,
 'len_pd2': 25.430168306176174,
 'rate3': 0.15855110083435045,
 'len_pd3': 17.93723096801004,
 'rate4': 0.44742210379690484,
 'len_pd4': 9.473625184649173,
 'rate5': 0.31347721972627246,
 'len_pd5': 19.544555030270217,
 'rate6': 0.3574721524945594,
 'len_pd6': 19.927150189484074,
 'rate7': 0.2811021413330327,
 'len_pd7': 11.43900246393273,
 'rate8': 0.16561762321968745,
 'len_pd8': 17.837972118670084,
 'rate9': 0.19550098165446675,
 'len_pd9': 17.979573168300647,
 'rate10': 0.1545725401394826,
 'len_pd10': 12.750142586542298,
 'rate11': 0.17211439942667056,
 'len_pd11': 25.86398075936463,
 'rate12': 0.17991511303129631,
 'len_pd12': 22.774504649717258,
 'rate13': 0.20730008857792842,
 'len_pd13': 16.088196260273172,
 'rate14': 0.3018648256641086,
 'len_pd14': 13.756702484154493,
 'rate15': 0.13699138777487416,
 'notification_dispersion': 27.02499301473874}

In [66]:
# Run the model against the parameter estimates
res = bcm.run(mle_params)

In [67]:
#Generate breakpts and rates for mle
breakpts = []
s = mle_params["len_pd"+str(1)]
breakpts.append(s)
for i in range(2,num_breakpts+1):
    s += mle_params["len_pd"+str(i)]
    breakpts.append(s)
rates = []
for k in range(1,num_breakpts+2):
    rates.append(mle_params["rate"+str(k)])

print(breakpts)
print(rates)

#Plot "effective contact rate"
domain = [m.times.min(),m.times.max()]
dates = res.derived_outputs.notifications.keys()
f = pcwise_fcn(domain,breakpts,rates)
xlist = np.linspace(domain[0], domain[1], num = len(dates))
ylist = f(xlist)
fig = go.Figure()
fig = fig.add_trace(go.Scatter(x = dates, y=ylist, name = "eff. contact rate"))
fig.add_vrect(x0="2020-04-07", x1="2020-05-16", 
              annotation_text="ECQ", annotation_position="top left",
              fillcolor="aqua", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-05-16", x1="2020-06-01", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-06-01", x1="2020-08-03", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-03", x1="2020-08-19", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-19", x1="2020-12-31", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.show()

[18.255506856156128, 43.6856751623323, 61.62290613034234, 71.09653131499152, 90.64108634526174, 110.56823653474581, 122.00723899867853, 139.8452111173486, 157.82478428564926, 170.57492687219155, 196.4389076315562, 219.21341228127346, 235.30160854154664, 249.05831102570113]
[0.35507553978457307, 0.2528228780303829, 0.15855110083435045, 0.44742210379690484, 0.31347721972627246, 0.3574721524945594, 0.2811021413330327, 0.16561762321968745, 0.19550098165446675, 0.1545725401394826, 0.17211439942667056, 0.17991511303129631, 0.20730008857792842, 0.3018648256641086, 0.13699138777487416]


In [68]:
#Plot "effective contact rate" V2
target = "notifications"

domain = [m.times.min(),m.times.max()]
dates = res.derived_outputs.notifications.keys()
f = pcwise_fcn(domain,breakpts,rates)
xlist = np.linspace(domain[0], domain[1], num = len(dates))
ylist1 = f(xlist)

fig = make_subplots(specs=[[{"secondary_y": True}]])

fig = fig.add_trace(go.Scatter(x = dates, y=ylist, name = "eff. contact rate"),secondary_y=False)

fig = fig.add_trace(go.Scatter(x = res.derived_outputs[target].keys(), y = res.derived_outputs[target], name = "MLE"), secondary_y=True)
fig = fig.add_trace(go.Scatter(x = bcm.targets[target].data.keys(), y = bcm.targets[target].data, name = "data", marker = dict(color='orange')), secondary_y=True)

fig.add_vrect(x0="2020-04-07", x1="2020-05-16", 
              annotation_text="ECQ", annotation_position="top left",
              fillcolor="aqua", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-05-16", x1="2020-06-01", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-06-01", x1="2020-08-03", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-03", x1="2020-08-19", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-19", x1="2020-12-31", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.show()

In [69]:
target = "notifications"

fig = go.Figure()
fig = fig.add_trace(go.Scatter(x = res.derived_outputs[target].keys(), y = res.derived_outputs[target], name = "MLE"))
fig = fig.add_trace(go.Scatter(x = bcm.targets[target].data.keys(), y = bcm.targets[target].data, name = "data", marker = dict(color='orange')))
fig.show()

In [70]:
# Inspect the bias of the resulting output

diff = (res.derived_outputs[target][::2] - bcm.targets[target].data)
print(diff.mean() / diff.std())
fig = go.Figure()
fig = fig.add_trace(go.Scatter(x = diff.keys(), y = diff, name = "diff"))
fig.show()

0.05510971428137481
