In [19]:
from bqplot import *
import bqplot.marks as bqm
import bqplot.scales as bqs
import bqplot.axes as bqa
import numpy as np
import bqplot as bq
from IPython.display import display
import ipywidgets as widgets
from scipy.integrate import odeint

# Record the widget/plotting library versions this notebook was last run with.
print(f'ipywidgets version {widgets.__version__}')
print(f'bqplot version {bq.__version__}')

ipywidgets version 7.5.1
bqplot version 0.12.5


In [69]:
# use solve_ivp
from scipy.integrate import solve_ivp


# The x-axis shows calendar dates (bqplot DateScale), offset from the 2020-03-04 start date.


# Total population, N.
N = 330e6
# Initial infected (I0) and recovered/removed (R0) head counts.
# NOTE: this R0 is the number of recovered individuals at t=0, not the
# basic reproduction number controlled by R0_slider further down.
I0 = 118
R0 = 0
# Everyone not initially infected or recovered starts out susceptible.
S0 = N - I0 - R0
# Mean recovery rate gamma and contact rate beta (both in 1/days). These are
# only initial placeholders: the simulation code below recomputes both from
# the slider values before integrating.
gamma = 1. / 14
beta = 0.3

# The SIR model differential equations.
def deriv(t, y, N, beta, gamma):
    """Right-hand side of the SIR ODE system, in the form ``solve_ivp`` expects.

    Parameters
    ----------
    t : float
        Current time (unused; the system is autonomous).
    y : sequence of three numbers
        Current compartment sizes ``(S, I, R)``.
    N : float
        Total population.
    beta : float
        Contact rate.
    gamma : float
        Mean recovery rate.

    Returns
    -------
    tuple
        ``(dS/dt, dI/dt, dR/dt)``.
    """
    susceptible, infected, _recovered = y
    # Flow S -> I from new infections, and flow I -> R from recoveries.
    infection_flux = beta * susceptible * infected / N
    recovery_flux = gamma * infected
    return -infection_flux, infection_flux - recovery_flux, recovery_flux

# Initial conditions vector (S, I, R) handed to solve_ivp as its y0.
y0 = S0, I0, R0

def update_plot(*args):
    """Re-run the SIR simulation with the current slider values and redraw.

    Registered as an ``observe`` callback on ``R0_slider`` and
    ``duration_slider``; the change notification passed in ``args`` is
    ignored. Reads the module-level ``deriv``, ``y0`` and ``N`` and writes
    new data into the ``scat_I`` / ``scat_S`` / ``scat_R`` line marks.
    """
    gamma = 1 / duration_slider.value   # recovery rate (1/days)
    beta = R0_slider.value * gamma      # chosen so that beta / gamma == slider R0
    ret = solve_ivp(deriv, [0, 300], y0, args=(N, beta, gamma), max_step=2)
    S, I, R = ret.y

    # ret.t holds fractional days. Multiplying a day-unit timedelta64 by a
    # float truncates to whole days (collapsing several solver points onto
    # one date), so offset the start date in whole seconds instead.
    initial_date = np.datetime64('2020-03-04')
    offsets = np.round(ret.t * 86400).astype('timedelta64[s]')
    dates = initial_date + offsets

    # The solver picks its own time points, so their count may change with
    # beta/gamma. hold_sync sends each mark's new x and y together, so the
    # front end never sees arrays of mismatched length.
    for mark, values in ((scat_I, I), (scat_S, S), (scat_R, R)):
        with mark.hold_sync():
            mark.x = dates
            mark.y = values / 1e6

# Sliders for the basic reproduction number (R0 = beta / gamma) and the
# mean infectious period in days (gamma = 1 / duration).
R0_slider = widgets.FloatSlider(value=2.0, min=1.0, max=5.0, step=0.01, description='R0')
duration_slider = widgets.FloatSlider(value=15, min=10, max=20, step=0.01, description='duration')

# Calendar dates on x, population in millions on y.
sc_x = bqs.DateScale()
sc_y = bqs.LinearScale()

# Initial simulation from the slider defaults (same model as update_plot).
gamma = 1/duration_slider.value
beta = R0_slider.value * gamma

simulation = solve_ivp(deriv, [0, 300], y0, args=(N, beta, gamma), max_step=2)
S, I, R = simulation.y
days = simulation.t
initial_date = np.datetime64('2020-03-04')
# `days` is fractional; timedelta64[D] * float would truncate to whole days,
# so offset the start date in whole seconds instead.
dates = initial_date + np.round(days * 86400).astype('timedelta64[s]')

scat_I = bqm.Lines(x=dates, y=I/1e6, scales={'x': sc_x, 'y': sc_y}, colors=['red'], labels=['Infected'], display_legend=True)
scat_S = bqm.Lines(x=dates, y=S/1e6, scales={'x': sc_x, 'y': sc_y}, colors=['blue'], labels=['Susceptible'], display_legend=True)
scat_R = bqm.Lines(x=dates, y=R/1e6, scales={'x': sc_x, 'y': sc_y}, colors=['green'], labels=['Removed'], display_legend=True)

fig = bq.Figure(marks=[scat_I, scat_S, scat_R],
                axes=[bqa.Axis(scale=sc_x, label='Date'),
                      bqa.Axis(scale=sc_y, label='Population (millions)', orientation='vertical')],
                title='SIR epidemic model',
                legend_location='right')

# Register the callbacks only once the marks they update exist.
R0_slider.observe(update_plot, 'value')
duration_slider.observe(update_plot, 'value')

box = widgets.VBox([R0_slider, duration_slider, fig])

display(box)


VBox(children=(FloatSlider(value=2.0, description='R0', max=5.0, min=1.0, step=0.01), FloatSlider(value=15.0, …