# In [None]:
import json
import os
import time

import ipywidgets as widgets
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import logic_iterative as logic

# Initialise logic_iterative's global settings before the charts and widgets
# below are built (getCharts and getButtons read logic.Globals).
# NOTE(review): updateScale calls Globals.set(metricID=..., scale=...) by
# keyword; the two 0s here are presumably the default metric and scale
# selections — confirm against logic_iterative.
logic.Globals.set(0, 0)

def getCharts() -> go.FigureWidget:
    """Build the side-by-side choropleth + pie figure for the current scale.

    Loads ``assets/<scale>/geo.json`` for ``logic.Globals.scale`` and seeds
    both traces from ``logic.State.getDummyDataFrame()``:

    * left (2/3 width): a choropleth over the US, each area coloured by its
      ``district`` value, with its ``metric`` as hover text;
    * right (1/3 width): a pie of ``metric`` per ``district`` label (plotly
      sums the values of repeated labels).

    Returns:
        A ``go.FigureWidget``, so later ``update_traces`` calls re-render the
        displayed figure in place.
    """
    geo_path = os.path.join("assets", logic.Globals.scale, "geo.json")
    with open(geo_path, encoding='utf8') as file:
        geojson = json.load(file)

    colors = px.colors.qualitative.Vivid

    frame = logic.State.getDummyDataFrame()

    fig = make_subplots(rows=1, cols=2, column_widths=[2, 1], specs=[[{'type': 'choropleth'}, {'type': 'pie'}]])
    fig.add_choropleth(
        geojson=geojson,
        locations=frame['code'],
        z=frame['district'],
        text=frame['metric'],
        # A plain list of colours becomes a colorscale whose stops are evenly
        # spaced over [0, 1], i.e. colour i sits at i / (len(colors) - 1).
        # With zmax = len(colors) - 1, integer district i normalises to exactly
        # that stop and gets a pure palette colour. The previous
        # zmax = len(colors) put every district > 0 between two stops, which
        # rendered blended colours.
        zmin=0, zmax=len(colors) - 1,
        legendgroup="group",
        colorscale=colors,
        row=1, col=1
    )
    fig.update_layout(
        geo={'scope': 'usa'},
        height=700
    )
    fig.add_pie(
        values=frame['metric'],
        labels=frame['district'],
        textinfo="label+value+percent",
        legendgroup="group",
        marker=dict(colors=colors),
        row=1, col=2
    )
    return go.FigureWidget(fig)
# Module-level figure widget; updateFigs updates its traces and updateScale
# rebinds it.
plots = getCharts()

def getButtons():
    """Create the control widgets for the solver UI.

    Returns:
        A dict keyed by role — ``'advance'``, ``'solve'``, ``'scale'``,
        ``'count'``, ``'metric'``, ``'delay'`` — in that insertion order, which
        is also the left-to-right order when laid out via ``inputs.values()``.
        Advance starts disabled; the Scale and Metric combo boxes only accept
        values from their option lists.
    """
    advance_button = widgets.Button(description="Advance", disabled=True)
    solve_button = widgets.Button(description="Solve")
    scale_picker = widgets.Combobox(
        value=logic.Globals.scale,
        options=logic.Globals.scales,
        description='Scale:',
        ensure_option=True,
        disabled=False,
    )
    district_count = widgets.BoundedIntText(
        value=3,
        min=1,
        max=10,
        step=1,
        description='Count:',
        disabled=False,
    )
    metric_picker = widgets.Combobox(
        value=logic.Globals.metricID,
        options=logic.Globals.allowed,
        description='Metric:',
        ensure_option=True,
        disabled=False,
    )
    # Delay between rendered solver frames, in milliseconds.
    frame_delay = widgets.BoundedIntText(
        value=50,
        min=50,
        max=500,
        step=50,
        description='Display delay (ms):',
        disabled=False,
    )

    controls = {}
    controls['advance'] = advance_button
    controls['solve'] = solve_button
    controls['scale'] = scale_picker
    controls['count'] = district_count
    controls['metric'] = metric_picker
    controls['delay'] = frame_delay
    return controls
# Module-level widget registry; the handlers below look widgets up by key.
inputs = getButtons()

def updateFigs(data):
    """Recolour both subplots of the module-level ``plots`` from one frame.

    Args:
        data: iterable of rows; element 2 of each row is taken as that area's
            district. Presumably the rows are in the same order as the
            choropleth locations / pie labels — confirm against
            logic_iterative.
    """
    districts = [entry[2] for entry in data]
    # TODO: distinct choropleth traces?
    # batch_update() sends both trace edits to the front end as a single
    # message, so each frame is rendered once instead of twice.
    with plots.batch_update():
        plots.update_traces(z=districts, selector=dict(type='choropleth'))
        plots.update_traces(labels=districts, selector=dict(type='pie'))

def updateFrame(data):
    """Solver callback: draw one frame, enable Advance, then pause.

    The pause length is the 'delay' widget's value, in milliseconds.
    """
    updateFigs(data)
    advance_button = inputs['advance']
    advance_button.disabled = False
    pause_seconds = inputs['delay'].value / 1000
    time.sleep(pause_seconds)

def advance(bt):
    """Click handler for Advance: disable the button and take one solver step.

    Replaces the module-level ``state`` with ``logic.doStep(state)``.
    NOTE(review): ``state`` is only assigned by this function and by
    runSolver after ``logic.solve`` returns, so a click processed before then
    would raise NameError — confirm this cannot happen.
    """
    global state
    bt.disabled = True
    next_state = logic.doStep(state)
    state = next_state

def runSolver(bt):
    """Click handler for Solve: run the solver with the current widget settings.

    While ``logic.solve`` runs, Solve is disabled and Advance is enabled.
    Each intermediate frame is passed to ``updateFrame``. The result is stored
    in the module-level ``state``. The buttons are restored afterwards, even
    if the solver raises; the exception still propagates.
    """
    global state
    bt.disabled = True
    inputs['advance'].disabled = False
    try:
        state = logic.solve(
            numDist=inputs['count'].value,
            scale=inputs['scale'].value,
            metricID=inputs['metric'].value,
            callback=updateFrame,
        )
    finally:
        # Without this, an exception in the solver would leave Solve
        # permanently disabled and Advance enabled.
        bt.disabled = False
        inputs['advance'].disabled = True

def updateScale(change):
    """Observer for the Scale box: re-initialise globals and redraw the charts.

    ``change`` matches the ipywidgets observer signature but is not used.
    Rebinds the module-level ``plots`` to a freshly built figure and displays
    a new controls + charts layout.

    NOTE(review): this handler is currently unregistered (see the TODO where
    handlers are wired). Two things to check before re-enabling it:
    - The module-level call passes Globals.set two positional arguments;
      confirm that Globals.set actually accepts ``callback=``.
    - This displays an additional layout rather than replacing the existing
      one.
    """
    global plots
    logic.Globals.set(
        metricID=inputs['metric'].value,
        scale=inputs['scale'].value,
        callback=updateFrame,
    )
    plots = getCharts()
    controls_row = widgets.HBox(list(inputs.values()))
    display(widgets.VBox([controls_row, plots]))


# Wire button clicks to their handlers.
inputs['advance'].on_click(advance)
inputs['solve'].on_click(runSolver)
# TODO: fix — updateScale is not hooked up to the Scale box yet. When
# re-enabling, consider observe(updateScale, names='value') so it fires only
# on value changes rather than on every trait change.
#inputs['scale'].observe(updateScale)

# Show the controls in one row above the charts.
display(widgets.VBox([widgets.HBox([*inputs.values()]), plots]))