In [78]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.signal import savgol_filter

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.functions import time as stf
pd.options.plotting.backend = "plotly"

In [79]:
from main import model
from main import helper
from main import target_data

In [80]:
from importlib import reload
reload(model)
reload(helper)
reload(target_data)

<module 'main.target_data' from '/Users/mark/dev/DFAT/organized/main/target_data.py'>

In [81]:
#Build model
num_breakpts = 'fixed1' #integer or 'fixed1'
transmission_modifier_mode = 'sigmoidal' #pcwise_constant OR sigmoidal OR linear_interp
m = model.build_model(num_breakpts, transmission_modifier_mode)

In [82]:
#Run model with default parameters
defp = model.generate_default_parameters(m,num_breakpts,transmission_modifier_mode)
m.run(defp)

In [83]:
#Import notification target
notification_target = target_data.import_notification_target()

**Optimization**

In [84]:
from estival import targets as est
from estival import priors as esp
from estival.model import BayesianCompartmentalModel

In [85]:
import nevergrad as ng
from estival.wrappers.nevergrad import optimize_model

In [86]:
targets = [
    est.TruncatedNormalTarget("notifications", notification_target, (0.0,np.inf),
        esp.UniformPrior("notification_dispersion",(0.05, 5))),
]

In [87]:
tau = 14.0
transmission_priors = helper.generate_transmission_priors(num_breakpts)

priors = [
    esp.UniformPrior("detection_rate", (4/tau,0.7))
] + transmission_priors

In [88]:
bcm = BayesianCompartmentalModel(m, defp, priors, targets)

In [89]:
# TwoPointsDE is a good suggested default for some of our typical use cases
opt_class = ng.optimizers.TwoPointsDE
orunner = optimize_model(bcm, opt_class=opt_class)

In [90]:
# You can also suggest starting points for the optimization (as well as specify an init method for unsuggested points)
# This is the "midpoint" method by default (ie the 0.5 ppf of the prior distribution)
orunner = optimize_model(bcm, opt_class=opt_class, suggested=defp, init_method="midpoint")

In [91]:
# Here we run the optimizer in a loop, inspecting the current best point at each iteration
# Using the loss information at each step can provide the basis for stopping conditions

for i in range(8):
    # Run the minimizer for a specified 'budget' (ie number of evaluations)
    rec = orunner.minimize(2000)
    # Print the loss (objective function value) of the current recommended parameters
    print(rec.loss)

1159.8414333899907
910.8241970691647
695.2821997202974
598.3842451864356
556.777733660964
540.221345724144
534.8589921576107
531.8106113826243


In [92]:
mle_params = rec.value[1]
mle_params

{'detection_rate': 0.3640459569335504,
 'val0': 0.5963468238097261,
 'val1': 0.34018566568950803,
 'val2': 0.1998418649887029,
 'val3': 0.8414727511010768,
 'val4': 0.49985240194744507,
 'val5': 0.5463744622765456,
 'val6': 0.518881050205982,
 'val7': 0.25478267444362873,
 'val8': 0.36079962675704963,
 'val9': 0.2860222472564566,
 'val10': 0.2937521171819103,
 'val11': 0.38121673799862194,
 'val12': 0.3732658873338873,
 'val13': 0.3784028996929435,
 'val14': 0.6134187375828795,
 'notification_dispersion': 4.972517709019251}

In [93]:
# Run the model against the parameter estimates
res = bcm.run(mle_params)

In [94]:
epoch = m.get_epoch()
epoch

Epoch from 2020-04-07 00:00:00, in units of 1 day, 0:00:00

In [95]:
epoch.datetime_to_number(datetime(2020,12,31))

268.0

In [96]:
#Plot "transmission modifier"
domain = [m.times.min(),m.times.max()]
dates = res.derived_outputs.notifications.keys()
if transmission_modifier_mode == 'pcwise_constant':
    f = helper.get_pcwise_transmission_modifier(mle_params,num_breakpts)
elif transmission_modifier_mode == 'sigmoidal':
    f = helper.get_sigmoidal_transmission_modifier([m.times.min(),m.times.max()],mle_params,num_breakpts)
else:
    f = helper.get_linear_interp_transmission_modifier([m.times.min(),m.times.max()],mle_params,num_breakpts)
xlist = np.linspace(domain[0], domain[1], num = len(dates))
ylist = f(xlist)
fig = go.Figure()
fig = fig.add_trace(go.Scatter(x = dates, y=ylist, name = "transmission modifier"))
fig.add_vrect(x0="2020-04-07", x1="2020-05-16", 
              annotation_text="ECQ", annotation_position="top left",
              fillcolor="aqua", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-05-16", x1="2020-06-01", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-06-01", x1="2020-08-03", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-03", x1="2020-08-19", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-19", x1="2020-12-31", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.show()

In [97]:
#Plot "effective contact rate" V2
target = "notifications"

domain = [m.times.min(),m.times.max()]
dates = res.derived_outputs.notifications.keys()
xlist = np.linspace(domain[0], domain[1], num = len(dates))
ylist1 = f(xlist)

fig = make_subplots(specs=[[{"secondary_y": True}]])

fig = fig.add_trace(go.Scatter(x = dates, y=ylist, name = "eff. contact rate"),secondary_y=False)

fig = fig.add_trace(go.Scatter(x = res.derived_outputs[target].keys(), y = res.derived_outputs[target], name = "MLE"), secondary_y=True)
fig = fig.add_trace(go.Scatter(x = bcm.targets[target].data.keys(), y = bcm.targets[target].data, name = "data", marker = dict(color='orange')), secondary_y=True)

fig.add_vrect(x0="2020-04-07", x1="2020-05-16", 
              annotation_text="ECQ", annotation_position="top left",
              fillcolor="aqua", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-05-16", x1="2020-06-01", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-06-01", x1="2020-08-03", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-03", x1="2020-08-19", 
              annotation_text="MECQ", annotation_position="top left",
              fillcolor="pink", opacity=0.25, line_width=0)
fig.add_vrect(x0="2020-08-19", x1="2020-12-31", 
              annotation_text="GCQ", annotation_position="top left",
              fillcolor="lightgreen", opacity=0.25, line_width=0)
fig.show()

**Model inspections**

In [98]:
#Generate test parameters (for sigmoidal or linear_interp), different from default params.
test_param = {
    'val0': 0.1,
    'val1': 0.5,
    'val2': 0.9,
    'val3': 0.3,
    'val4': 0.2,
    'val5': 0.9,
    'val6': 0.5,
    'detection_rate': 0.3
}
m.run(test_param)

GraphRunError: ('parameters.val7', KeyError('val7'))

In [None]:
#Plot compartment outputs
m.get_outputs_df().plot()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



In [None]:
#Plot notifications
m.get_derived_outputs_df().plot()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



In [None]:
#Plot transmission modifier (uncomment one of the three below)
#F = helper.get_pcwise_transmission_modifier(test_param,num_breakpts)
F = helper.get_sigmoidal_transmission_modifier([m.times.min(),m.times.max()],defp,num_breakpts)
#F = helper.get_linear_interp_transmission_modifier([m.times.min(),m.times.max()],test_param,num_breakpts)
#
# Plot transmission modifier
helper.plot_transmission_modifier(m,F)

In [None]:
#Plot notification target
notification_target.plot()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result

