In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import LogNorm
import numpy as np

In [None]:
from settlements import parse_settlements
from spatial_sim import Params, init_state, step_state

In [None]:
# Parameter object shared by init_state/step_state; the sliders defined
# further down overwrite beta, seasonality and demog_scale on this same
# object while the animation is running.
params = Params(
    beta=32,
    seasonality=0.16,
    demog_scale=1.0,
    mixing_scale=0.002,
    distance_exponent=1.5,
)

In [None]:
# Settlement table; its Long, Lat and population columns drive the scatter
# plot in the animation cell.
settlements_df = parse_settlements()
# Initial simulation state built from the settlements and the current
# parameters. It is indexed as a 2-D array (state[:, 1]) when plotting.
state = init_state(settlements_df, params)
# Time-step value passed as the third argument to step_state.
t = 0

In [None]:
from ipywidgets import FloatSlider

def on_beta_change(v):
    """Store the slider's new value in ``params.beta``.

    ``v`` is the change payload delivered by ``observe``; only its
    ``"new"`` entry is read.
    """
    new_beta = v["new"]
    params.beta = new_beta


# Slider initialised from the current params.beta.
beta_slider = FloatSlider(
    description='beta',
    value=params.beta,
    min=0,
    max=50,
    step=1,
)
beta_slider.observe(on_beta_change, names='value')

# Slider initialised from the current params.seasonality.
seasonality_slider = FloatSlider(value=params.seasonality, min=0, max=0.3, step=0.02, description='seasonality')
def on_seasonality_change(v):
    """Store the slider's new value in ``params.seasonality``.

    ``v`` is the change payload delivered by ``observe``; only its
    ``"new"`` entry is read.
    """
    params.seasonality = v["new"]
# The callback used to be spelled "on_seaonality_change"; keep that name
# bound so anything still referring to the old spelling keeps working.
on_seaonality_change = on_seasonality_change
seasonality_slider.observe(on_seasonality_change, names='value')

def on_demog_scale_change(v):
    """Apply a new demographic scale to ``params``.

    Records the new scale in ``params.demog_scale`` and then rescales
    ``params.biweek_avg_births`` and ``params.biweek_death_prob`` in place
    by the ratio new/old taken from the change payload ``v``.
    """
    new_scale = v["new"]
    old_scale = v["old"]
    params.demog_scale = new_scale
    # Relative change, so the rates follow the slider multiplicatively.
    ratio = new_scale / old_scale
    params.biweek_avg_births *= ratio
    params.biweek_death_prob *= ratio


# Slider initialised from the current params.demog_scale; min is 0.1, so
# values chosen through the slider are never zero.
demog_scale_slider = FloatSlider(
    description='demog_scale',
    value=params.demog_scale,
    min=0.1,
    max=1.5,
    step=0.05,
)
demog_scale_slider.observe(on_demog_scale_change, names='value')

In [None]:
# "widget" for vscode
%matplotlib widget

# "notebook" for jupyter notebook (+ add a plt.show() after animation.FuncAnimation function call)
# %matplotlib notebook

fig, ax = plt.subplots()


def _column1_fraction():
    """Return, for every row of ``state``, column 1 divided by the row total.

    This is the colour value of each scatter point.
    NOTE(review): column 1 is presumably the infected compartment -- confirm
    against init_state/step_state.
    """
    return state[:, 1] / state[:, :].sum(axis=-1)


# One marker per settlement: position from Long/Lat, marker size from the
# square root of the population, colour on a log scale.
scat = ax.scatter(
    settlements_df.Long,
    settlements_df.Lat,
    s=0.1*np.sqrt(settlements_df.population),
    c=_column1_fraction(),
    cmap="Reds", norm=LogNorm(vmin=1e-4, vmax=0.01), alpha=0.5)

# Toggled by mouse clicks on the figure (see onClick).
paused = False


def simulate_step():
    """Endless generator that advances the simulation once per ``next()``.

    While ``paused`` is False each ``next()`` calls ``step_state`` with the
    current global ``t``, increments ``t`` and yields the call's result.
    While ``paused`` is True it yields None without stepping, so the
    generator never terminates and ``next()`` never raises StopIteration.

    NOTE(review): ``t`` is assumed to be the step index expected by
    step_state -- confirm in spatial_sim. It previously stayed at its
    initial value forever.
    """
    global t
    while True:
        if paused:
            yield None
        else:
            result = step_state(state, params, t)
            t += 1
            yield result


# A single generator shared by all frames. Creating a fresh one per frame
# made a paused frame raise StopIteration inside the animation callback.
_stepper = simulate_step()


def animate(i):
    """Per-frame callback for FuncAnimation.

    Advances the simulation by at most one step and refreshes the title and
    the marker colours. ``i`` is the frame counter supplied by FuncAnimation;
    it is not used for the title because it keeps counting while paused.

    Returns a one-element tuple containing the scatter artist.
    """
    next(_stepper)
    # 26 steps are shown as one year; t counts the steps actually simulated.
    ax.set_title("{:.2f} years".format(t/26.))
    scat.set_array(_column1_fraction())
    return scat,


# Keep a reference in ``ani`` so the animation is not garbage-collected.
# Frame data are plain integers from an unbounded counter, so caching them is
# pointless (and recent matplotlib warns about it when save_count is unset).
ani = animation.FuncAnimation(fig, animate, cache_frame_data=False, interval=50, blit=False)


def onClick(event):
    """Toggle the paused flag on any mouse-button press in the figure."""
    global paused
    paused = not paused


# Keep the connection id: this avoids echoing it as cell output and allows a
# later fig.canvas.mpl_disconnect(_click_cid).
_click_cid = fig.canvas.mpl_connect('button_press_event', onClick)

In [None]:
# Show the three parameter sliders, in the same order, with one call.
display(beta_slider, seasonality_slider, demog_scale_slider)