In [None]:
import numpy as np
import panel as pn

# Theme colours for the app.
# NOTE(review): neither colour is referenced by any later cell — wire into the
# template or remove.
PRIMARY_COLOR, SECONDARY_COLOR = "#0072B5", "#B54300"

# Initialise Panel: Material design, components stretch to the available width.
pn.extension(sizing_mode="stretch_width", design="material")

In [None]:
from settlements import parse_settlements
from spatial_sim import Params, init_state, step_state

In [None]:
@pn.cache
def get_data():
    """Parse the settlements table, memoised by ``pn.cache``.

    When the notebook is served, repeated sessions reuse the cached result
    instead of re-parsing.
    """
    parsed = parse_settlements()
    return parsed

settlements_df = get_data()

# Quick look at the first few settlements.
settlements_df.head()

In [None]:
def reset_params():
    """Return a new ``Params`` populated with this notebook's default settings."""
    defaults = dict(
        beta=32,
        seasonality=0.16,
        demog_scale=1.0,
        mixing_scale=0.002,
        distance_exponent=1.5,
    )
    return Params(**defaults)

params = reset_params()

In [None]:
def reset_state():
    """Return a fresh initial simulation state for ``settlements_df``.

    Reads the module-level ``params`` at call time, so callers that rebind
    ``params`` first (as the Reset button does) get a state built from the
    new parameters.
    """
    initial_state = init_state(settlements_df, params)
    return initial_state

state = reset_state()

In [None]:
# --- Widgets for the core simulation parameters -----------------------------
beta_slider = pn.widgets.FloatSlider(
    name='beta', value=params.beta, start=0, end=50, step=1)
seasonality_slider = pn.widgets.FloatSlider(
    name='seasonality', value=params.seasonality, start=0, end=0.3, step=0.02)
demog_scale_slider = pn.widgets.FloatSlider(
    name='demog_scale', value=params.demog_scale, start=0.1, end=1.5, step=0.05)


# --- Callbacks that copy slider values into the global ``params`` -----------
def on_beta_change(value):
    """Write the slider value into ``params.beta``."""
    params.beta = value


def on_seasonality_change(value):
    """Write the slider value into ``params.seasonality``."""
    params.seasonality = value


def on_demog_scale_change(value):
    """Write ``params.demog_scale`` and refresh the quantities derived from it.

    The ``/ 26.`` converts to per-biweek values, assuming ``params.births`` is
    an annual figure (TODO confirm in ``spatial_sim``).
    """
    params.demog_scale = value
    params.biweek_avg_births = params.demog_scale * params.births / 26.
    params.biweek_death_prob = params.demog_scale * params.births / params.population / 26.


# pn.bind ties each callback to its slider. The bound objects are rendered next
# to the sliders in the sidebar, which is what makes Panel call them on change.
bound_beta = pn.bind(on_beta_change, value=beta_slider)
bound_seasonality = pn.bind(on_seasonality_change, value=seasonality_slider)
bound_demog_scale = pn.bind(on_demog_scale_change, value=demog_scale_slider)

In [None]:
from mixing import init_gravity_diffusion

def _rebuild_mixing():
    """Recompute ``params.mixing`` from the current mixing scale and distance exponent."""
    params.mixing = init_gravity_diffusion(
        settlements_df, params.mixing_scale, params.distance_exponent)


# The mixing scale spans orders of magnitude, so the slider works in log10 space.
mixing_scale_slider = pn.widgets.FloatSlider(
    name='log10(mixing_scale)', value=np.log10(params.mixing_scale), start=-4, end=-2)
distance_exponent_slider = pn.widgets.FloatSlider(
    name='distance_exponent', value=params.distance_exponent, start=0.5, end=2.5, step=0.1)


def on_mixing_scale_change(value):
    """Convert the log10 slider value back to a scale and rebuild the mixing matrix."""
    params.mixing_scale = np.power(10, value)
    _rebuild_mixing()


def on_distance_exponent_change(value):
    """Store the new distance exponent and rebuild the mixing matrix."""
    params.distance_exponent = value
    _rebuild_mixing()


bound_mixing_scale = pn.bind(on_mixing_scale_change, value=mixing_scale_slider)
bound_distance_exponent = pn.bind(on_distance_exponent_change, value=distance_exponent_slider)

In [None]:
from bokeh import models, plotting, io
from bokeh.palettes import Reds256

# Per-settlement data for the map. Prevalence is column 1 of ``state`` divided
# by the row total (presumably the infected compartment — confirm in spatial_sim).
initial_prevalence = state[:, 1] / state[:, :].sum(axis=-1)
source = models.ColumnDataSource({
    "name": settlements_df.index,
    "x": settlements_df.Long,
    "y": settlements_df.Lat,
    "size": 0.03 * np.sqrt(settlements_df.population),  # marker area grows with population
    "prevalence": initial_prevalence,
})

# Log-scaled colour mapping for prevalence between 1e-4 and 1e-2. The palette
# is reversed so that larger values get darker reds.
exp_cmap = models.LogColorMapper(palette=Reds256[::-1], low=1e-4, high=0.01)

p = plotting.figure(
    title="Prevalence",
    x_axis_label="Longitude",
    y_axis_label="Latitude",
)
p.scatter(
    source=source,
    x="x",
    y="y",
    size="size",
    alpha=0.5,
    color={"field": "prevalence", "transform": exp_cmap},
)

# NOTE(review): ``p`` is also placed in the template's main area further down;
# confirm adding it as a separate document root is intended (it may render twice
# when served).
io.curdoc().add_root(p)

def stream():
    """Periodic-callback body: advance the simulation and refresh the plot.

    ``state`` is not reassigned here, so this relies on ``step_state`` updating
    it in place (TODO confirm). ``state.t / 26.`` is shown as the year.
    """
    step_state(state, params)
    elapsed_years = state.t / 26.
    p.title.text = "Prevalence (year = {:.2f})".format(elapsed_years)
    row_totals = state[:, :].sum(axis=-1)
    source.data["prevalence"] = state[:, 1] / row_totals

# Default interval between simulation steps, in milliseconds.
callback_period = 50
callback = pn.state.add_periodic_callback(stream, callback_period)

speed_slider = pn.widgets.FloatSlider(value=callback_period, start=10, end=200, step=10, name='refresh rate (ms)')
def on_speed_change(value):
    """Apply the slider value (ms) as the periodic callback's period.

    ``PeriodicCallback.period`` is an integer parameter, but a FloatSlider may
    deliver floats (e.g. ``60.0``). Round and cast so the assignment does not
    fail validation.
    """
    callback.period = int(round(value))
bound_speed = pn.bind(on_speed_change, value=speed_slider)

reset_button = pn.widgets.Button(name='Reset', button_type='primary')
def reset(event):
    """Restore default parameters, a fresh initial state and default controls.

    ``params`` and ``state`` are rebound as globals. ``stream`` and the slider
    callbacks look them up by name at call time, so they pick up the new
    objects. Each slider is then set back to its default so the UI matches
    ``params``; where a value actually changes, its bound callback runs against
    the fresh ``params``. Finally the periodic callback is restarted if paused.

    :param event: click event supplied by ``Button.on_click`` (unused).
    """
    global params, state
    params = reset_params()
    state = reset_state()
    beta_slider.value = params.beta
    seasonality_slider.value = params.seasonality
    demog_scale_slider.value = params.demog_scale
    mixing_scale_slider.value = np.log10(params.mixing_scale)
    # Re-sync the distance exponent too. Otherwise the slider would keep showing
    # the old value while params.distance_exponent reverted to the default.
    distance_exponent_slider.value = params.distance_exponent
    speed_slider.value = callback_period
    if not callback.running:
        callback.start()
reset_button.on_click(reset)

# Two-way link between the toggle and the callback's ``running`` flag: pressing
# the toggle pauses/resumes the simulation, and starting the callback elsewhere
# (e.g. in ``reset``) updates the toggle.
pause_button = pn.widgets.Toggle(value=True, name='Pause/Resume')
pause_button.link(callback, value='running', bidirectional=True)

In [None]:
# Each slider is paired with its bound callback in a Row. Rendering the bound
# object is what makes Panel invoke the callback when the slider moves.
simulation_controls = [
    "### Simulation parameters",
    pn.Row(beta_slider, bound_beta),
    pn.Row(seasonality_slider, bound_seasonality),
    pn.Row(demog_scale_slider, bound_demog_scale),
]
mixing_controls = [
    "### Mixing parameters",
    pn.Row(mixing_scale_slider, bound_mixing_scale),
    pn.Row(distance_exponent_slider, bound_distance_exponent),
]
playback_controls = [
    "### Playback controls",
    pn.Row(speed_slider, bound_speed),
    pn.Row(reset_button, pause_button),
]

sliders = pn.Column(
    *simulation_controls,
    pn.layout.Divider(),
    *mixing_controls,
    pn.layout.Divider(),
    *playback_controls,
)

In [None]:
# Assemble the app: controls in the sidebar, prevalence map in the main area.
pn.template.MaterialTemplate(
    title="Interactive Spatial Simulation",
    site="numpy demo",
    main=[p],
    sidebar=[sliders],
).servable();  # trailing ';' stops the notebook from rendering the template inline

In [None]:
# To launch the app, run this in a terminal: panel serve interactive.ipynb --autoreload
# Then open http://localhost:<port>/interactive in a browser (Panel serves on port 5006 by default).