# In [2]:
from jupyter_plotly_dash import JupyterDash

import numpy as np
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly.graph_objs as go
from plotly.subplots import make_subplots

import sys
sys.path.append("../sim_jobs/worker")
from corona_sim import runPyGame, desease_color, desease_state, Corona_Simulation

# Plotly wants colours as "rgb(r, g, b)" strings; translate every entry of
# ``desease_color`` (each value is fed to a three-slot %-format).
plotly_desease_color = {
    state: "rgb(%i, %i, %i)" % desease_color[state]
    for state in desease_color
}

# Scatter figure of the agents: one (initially empty) marker trace per name
# in ``desease_state``, coloured through the state that name maps to.
simu_fig = go.Figure(
    data=[
        go.Scatter(
            x=[],
            y=[],
            mode="markers",
            name=state_name,
            marker={"color": plotly_desease_color[desease_state[state_name]]},
        )
        for state_name in desease_state
    ]
)

# Statistics figure with a primary and a secondary y axis, one trace on each.
stats_fig = make_subplots(specs=[[dict(secondary_y=True)]])
stats_fig.add_trace(
    go.Scatter(x=[], y=[], name="#sick agents"),
    secondary_y=False,
)
stats_fig.add_trace(
    go.Scatter(x=[], y=[], name="Social distancing active"),
    secondary_y=True,
)

# Holds the running simulation; stays ``None`` until one is created.
simulation = None

app = JupyterDash('SimpleExample')

# Page layout: both graphs, a start button, an info paragraph and a timer
# component that ticks every 250 ms.
app.layout = html.Div(
    children=[
        dcc.Graph(
            id="simulation-figure",
            figure=simu_fig,
            config={"editable": True},
        ),
        dcc.Graph(
            id="stats-figure",
            figure=stats_fig,
            config={"editable": True},
        ),
        html.Button('Start Simulation', id='button'),
        html.P(id="info_text", children="Test"),
        dcc.Interval(
            id='interval-component',
            interval=250,  # in milliseconds
            n_intervals=0,
        ),
    ]
)

def update_graphs2(sim):
    """Draw every agent of ``sim`` as a circle shape in the layout of ``fig``.

    On the first call (``fig.layout.shapes`` is empty) one circle per row of
    ``sim.centers`` is created; on later calls the existing shapes are moved
    and recoloured in place.

    NOTE(review): ``fig`` is read as a module-level name that is not defined
    in the visible part of this file -- confirm it exists before calling.

    Args:
        sim: simulation object providing ``centers`` (indexed as a 2-D array;
            element 0 / 1 of each row are used as x / y), ``agent_radius``
            and ``desease_state.deseases`` (one state per agent).

    Raises:
        Exception: if updating a shape fails; the message carries the agent
            index, the current shapes and the agent centre, and the original
            error is chained as ``__cause__``.
    """
    radius = sim.agent_radius

    if fig.layout.shapes == ():
        # No shapes yet: build the full list once and install it.
        shapes = [
            dict(
                type="circle", xref="x", yref="y",
                x0=center[0] - radius, y0=center[1] - radius,
                x1=center[0] + radius, y1=center[1] + radius,
                fillcolor=plotly_desease_color[sim.desease_state.deseases[anum]],
            )
            for anum, center in enumerate(sim.centers[:, :])
        ]
        fig.update_layout(shapes=shapes)
        return

    for anum, center in enumerate(sim.centers[:, :]):
        try:
            shape = fig.layout.shapes[anum]
            shape.x0 = center[0] - radius
            shape.y0 = center[1] - radius
            shape.x1 = center[0] + radius
            shape.y1 = center[1] + radius
            shape.fillcolor = plotly_desease_color[sim.desease_state.deseases[anum]]
        except Exception as err:
            # Only wrap ordinary errors (KeyboardInterrupt / SystemExit pass
            # through untouched) and keep the root cause in the traceback.
            raise Exception(f"{anum}, {fig.layout.shapes}, {center}") from err
    # Argument-less call kept as in the original code; presumably meant to
    # trigger a refresh -- TODO confirm it is needed.
    fig.update_layout()
                
def update_graphs(sim):
    """Copy the current state of ``sim`` into ``simu_fig`` and ``stats_fig``.

    Agent positions are grouped by state and written to the marker traces of
    ``simu_fig``; one new sample (``num_ticks / fps`` on x, ``sim_sick`` and
    ``sim_R_spread`` on y) is appended to the two traces of ``stats_fig``.

    Args:
        sim: simulation object providing ``centers``, ``num_ticks``, ``fps``
            and ``desease_state`` (with ``deseases``, ``sim_sick`` and
            ``sim_R_spread``).
    """
    # One coordinate bucket per entry of ``desease_state``.
    xs_by_state = [[] for _ in desease_state]
    ys_by_state = [[] for _ in desease_state]

    # ``[:, ::-1]`` reverses the column order of each row before it is read.
    flipped_centers = sim.centers[:, ::-1]
    for agent_idx, point in enumerate(flipped_centers):
        # State value minus one selects the bucket (a value of 0 would wrap
        # around to the last bucket).
        bucket = int(sim.desease_state.deseases[agent_idx]) - 1
        xs_by_state[bucket].append(point[0])
        ys_by_state[bucket].append(point[1])

    # Indices run -1, 0, ..., len - 2: the last trace is addressed through
    # -1, the remaining ones through their regular position.
    for idx in range(-1, len(xs_by_state) - 1):
        trace = simu_fig.data[idx]
        trace.x = xs_by_state[idx]
        trace.y = ys_by_state[idx]

    with stats_fig.batch_update():
        elapsed = sim.num_ticks / sim.fps
        sick_trace = stats_fig.data[0]
        spread_trace = stats_fig.data[1]

        if sick_trace.y is None:
            # Nothing stored yet: start both series with a single sample.
            sick_trace.x = [elapsed]
            spread_trace.x = [elapsed]
            sick_trace.y = [sim.desease_state.sim_sick]
            spread_trace.y = [sim.desease_state.sim_R_spread]
        else:
            # Stored values are extended with one-element tuples.
            sick_trace.x = sick_trace.x + (elapsed,)
            spread_trace.x = spread_trace.x + (elapsed,)
            sick_trace.y = sick_trace.y + (sim.desease_state.sim_sick,)
            spread_trace.y = spread_trace.y + (sim.desease_state.sim_R_spread,)
    stats_fig.update_layout()

@app.callback(
    Output('info_text', 'children'),
    [Input('button', 'n_clicks')])
def update_output(n_clicks):
    """Create a fresh simulation when the start button has been clicked.

    The new ``Corona_Simulation`` is stored in the module-level
    ``simulation`` variable, replacing any previous one.

    Args:
        n_clicks: click count of the button; ``None`` means no simulation is
            created (presumably the value Dash passes before the first
            click -- verify against the Dash version in use).

    Returns:
        ``str(n_clicks)``, shown as the text of the ``info_text`` paragraph.
    """
    global simulation

    if n_clicks is not None:
        # Configuration is only needed when a simulation is actually built.
        social_conf = dict(sd_impact=0.9, sd_start=0.05, sd_stop=0.01,
                           sd_recovered=True, know_rate_sick=0.75,
                           know_rate_recovered=0.1, party_freq=7,
                           party_R_boost=3)
        desease_conf = dict(R_spread=2.1, desease_duration=8, fatality=0.14,
                            initial_sick=5)

        simulation = Corona_Simulation(
            num_agents=500, height=600, width=600, fps=10,
            social_conf=social_conf, desease_conf=desease_conf,
            agent_radius=5, run=None, sim_md5=None)
    return str(n_clicks)

@app.callback(
    [Output("simulation-figure", "figure"), Output("stats-figure", "figure")],
    [Input("interval-component", "n_intervals")])
def update_graph(n_intervals):
    """Advance the simulation by one step and return both figures.

    Triggered by the ``interval-component`` timer. While no simulation exists
    or it has finished, the figures are returned without stepping.

    Args:
        n_intervals: tick counter of the interval component (unused; it only
            serves as the callback trigger).

    Returns:
        Tuple ``(simu_fig, stats_fig)`` for the two graph components.
    """
    # Fixed axis ranges of 0..600 on both axes of the simulation figure.
    simu_fig.update_xaxes(range=[0, 600])
    simu_fig.update_yaxes(range=[0, 600])

    # Identity check: does not depend on how the simulation class compares.
    if simulation is not None and not simulation.is_finished():
        simulation.update()
        update_graphs(simulation)

    return simu_fig, stats_fig

# Bare expression at the end of the cell: in a notebook its value is
# displayed (presumably rendering the Dash app inline -- TODO confirm).
app