In [None]:
import ipywidgets as widgets
import logic_iterative as logic
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import json
import time

# Initialize logic_iterative's global settings before any chart or widget is
# built: getCharts/getButtons below read logic.Globals.scale, .scales,
# .metricID and .allowed.
# NOTE(review): the meaning of the two positional 0s is defined in
# logic_iterative (presumably default metric and scale selections) — confirm there.
logic.Globals.set(0, 0)


def getCharts() -> go.FigureWidget:
    """Build the map + pie dashboard for the current scale as a FigureWidget.

    Loads ``assets/<logic.Globals.scale>/geo.json`` and seeds both traces from
    ``logic.State.getDummyData()``:

    * left (column width 2): a choropleth keyed by ``code``, colored by
      ``group`` on the Vivid palette, with ``metric`` as its text;
    * right (column width 1): a pie of ``metric`` values labelled by ``group``.

    Both traces share the ``"group"`` legend group.
    """
    geo_path = "/".join(("assets", logic.Globals.scale, "geo.json"))
    with open(geo_path, encoding="utf8") as handle:
        geojson = json.load(handle)

    palette = px.colors.qualitative.Vivid

    dummy = logic.State.getDummyData()

    figure = make_subplots(
        rows=1,
        cols=2,
        column_widths=[2, 1],
        specs=[[{"type": "choropleth"}, {"type": "pie"}]],
    )

    choropleth = go.Choropleth(
        geojson=geojson,
        locations=dummy["code"],
        z=dummy["group"],
        text=dummy["metric"],
        zmin=0,
        zmax=len(palette),
        legendgroup="group",
        colorscale=palette,
    )
    figure.add_trace(choropleth, row=1, col=1)

    figure.update_layout(geo={"scope": "usa"}, height=700)

    pie = go.Pie(
        values=dummy["metric"],
        labels=dummy["group"],
        textinfo="label+value+percent",
        legendgroup="group",
        marker=dict(colors=palette),
    )
    figure.add_trace(pie, row=1, col=2)

    return go.FigureWidget(figure)


# Module-level FigureWidget shown in the dashboard; updateFigs pushes new
# data into it and updateScale replaces it when the scale changes.
plots = getCharts()


# TODO: add a pause button?
def getButtons():
    """Create the dashboard's control widgets.

    Returns a dict whose insertion order is the order the widgets are laid
    out in: the Advance (initially disabled) and Solve buttons, the scale
    picker seeded from ``logic.Globals``, the group count (1-10, default 3),
    the metric picker, and the per-frame display delay in milliseconds
    (50-500, step 50).
    """
    settings = logic.Globals

    advance_button = widgets.Button(description="Advance", disabled=True)
    solve_button = widgets.Button(description="Solve")
    scale_picker = widgets.Combobox(
        value=settings.scale,
        options=settings.scales,
        description="Scale:",
        ensure_option=True,
        disabled=False,
    )
    group_count = widgets.BoundedIntText(
        value=3,
        min=1,
        max=10,
        step=1,
        description="Count:",
        disabled=False,
    )
    metric_picker = widgets.Combobox(
        value=settings.metricID,
        options=settings.allowed,
        description="Metric:",
        ensure_option=True,
        disabled=False,
    )
    delay_ms = widgets.BoundedIntText(
        value=50,
        min=50,
        max=500,
        step=50,
        description="Display delay (ms):",
        disabled=False,
    )

    return {
        "advance": advance_button,
        "solve": solve_button,
        "scale": scale_picker,
        "count": group_count,
        "metric": metric_picker,
        "delay": delay_ms,
    }


# Control widgets keyed by name ("advance", "solve", "scale", "count",
# "metric", "delay"); the handlers below look them up by these keys.
inputs = getButtons()


def updateFigs(data):
    """Push one frame of solver data into the module-level ``plots`` widget.

    Each entry of ``data`` is indexed positionally: ``entry[1]`` is the
    metric value and ``entry[2]`` the group. Groups drive the choropleth
    colors and the pie labels; metrics drive the choropleth text and the
    pie slice sizes.
    """
    metric_values = [row[1] for row in data]
    group_ids = [row[2] for row in data]

    # TODO: distinct choropleth traces?
    only_choropleth = dict(type="choropleth")
    only_pie = dict(type="pie")
    plots.update_traces(z=group_ids, text=metric_values, selector=only_choropleth)
    plots.update_traces(labels=group_ids, values=metric_values, selector=only_pie)


def updateFrame(data):
    """Solver callback: redraw with ``data``, re-enable Advance, then pause.

    The pause length is the "delay" widget's value, converted from
    milliseconds to seconds.
    """
    updateFigs(data)

    advance_button = inputs["advance"]
    advance_button.disabled = False

    pause_seconds = inputs["delay"].value / 1000
    time.sleep(pause_seconds)


def advance(bt):
    """Advance-button handler: disable the button and take one solver step.

    Replaces the module-level ``state`` with ``logic.doStep(state)``.
    """
    global state
    bt.disabled = True
    state = logic.doStep(state)


def runSolver(bt):
    """Solve-button handler: run the solver and store its result in ``state``.

    Reads the group count, scale and metric from the control widgets and
    passes ``updateFrame`` as the solver's callback. While the solver runs,
    the Solve button is disabled and Advance is enabled. Both are restored
    afterwards, even if the solver raises, so the UI is never left stuck.
    """
    global state
    bt.disabled = True
    inputs["advance"].disabled = False
    try:
        state = logic.solve(
            numGroup=inputs["count"].value,
            scale=inputs["scale"].value,
            metricID=inputs["metric"].value,
            callback=updateFrame,
        )
    finally:
        # Restore the buttons on both success and failure; any exception
        # still propagates to the widget framework.
        bt.disabled = False
        inputs["advance"].disabled = True


def updateScale(_):
    """Observer for the scale picker: rebuild the dashboard for a new scale.

    Does nothing when the picker already matches ``logic.Globals.scale``.
    Otherwise it reconfigures ``logic.Globals`` with the selected metric and
    scale, replaces the module-level ``plots`` with fresh charts, and
    displays a new controls + charts layout.
    """
    # NOTE(review): output produced inside widget callbacks may not render in
    # the originating cell (e.g. JupyterLab routes it to the log console);
    # consider a widgets.Output container — confirm in the target frontend.
    global plots
    chosen_scale = inputs["scale"].value
    if logic.Globals.scale == chosen_scale:
        return

    logic.Globals.set(
        metricID=inputs["metric"].value,
        scale=chosen_scale,
        callback=updateFrame,
    )
    plots = getCharts()
    controls = widgets.HBox([*inputs.values()])
    display(widgets.VBox([controls, plots]))


# Wire the buttons and the scale picker to their handlers. The scale observer
# is restricted to changes of "value"; without names= it would fire on every
# trait change of the Combobox.
inputs["advance"].on_click(advance)
inputs["solve"].on_click(runSolver)
inputs["scale"].observe(updateScale, names="value")

# Show the controls in one row above the charts.
display(widgets.VBox([widgets.HBox([*inputs.values()]), plots]))


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import cartopy.crs as ccrs
import cartopy.io.shapereader as shpreader
from cartopy.feature import ShapelyFeature
import logic_iterative as logic

# Fill colors indexed by placement/group number. Index 0 (light gray) is the
# color used when a unit has no placement (drawUnits defaults to 0).
colors = [ '#D3D3D3', '#462255', '#313B72', '#62A87C', '#DE4D86', '#FCD0A1' ]

def makeBlankMap():
    """Create a 20x15-inch figure holding three empty Lambert conformal maps.

    Returns a dict of cartopy GeoAxes keyed by '48' (continental United
    States), 'HI' (Hawaii inset) and 'AK' (Alaska inset). Each map's extent
    is set in Geodetic (lon/lat) coordinates.
    """
    projection = ccrs.LambertConformal()
    # plt.figure rather than plt.subplots: subplots() also adds a default
    # rectangular Axes that nothing draws into. It would sit underneath the
    # map axes and could show through outside the curved map outline.
    fig = plt.figure(figsize=(20, 15))

    maps = {}

    # Continental United States; the axes rect extends past the figure edges.
    maps['48'] = fig.add_axes([-.05, -.05, 1.2, 1.2], projection=projection)
    maps['48'].set_extent([-125, -66.5, 20, 50], ccrs.Geodetic())

    # Hawaii inset
    maps['HI'] = fig.add_axes([0.25, .1, 0.15, 0.15], projection=projection)
    maps['HI'].set_extent([-155, -165, 20, 15], ccrs.Geodetic())

    # Alaska inset
    maps['AK'] = fig.add_axes([0.1, 0.1, 0.2, 0.2], projection=projection)
    maps['AK'].set_extent([-185, -130, 70, 50], ccrs.Geodetic())

    return maps

def configureLegend(ax, data, colors):
    """Draw the group legend and the metric title on ``ax``.

    Each group in ``data.groups`` gets a square swatch filled with
    ``colors[group.index]`` and labelled with ``group.index``. No legend is
    drawn when there are no groups. The title is ``logic.Globals.metricID``.
    """
    groups = list(data.groups)
    if groups:
        handles = [
            mpatches.Rectangle((0, 0), 1, 1, facecolor=colors[group.index])
            for group in groups
        ]
        labels = [group.index for group in groups]
        # Build the legend once, after every handle exists, rather than
        # rebuilding it on each group.
        ax.legend(handles, labels,
                  loc='lower left', bbox_to_anchor=(0.025, -0.0),
                  fancybox=True, frameon=False, fontsize=15)

    ax.set_title(logic.Globals.metricID, fontsize=20)

def initializeMap(data):
    """Create the blank maps, add the legend for ``data``, and load unit shapes.

    Shapes are read from ``assets/<logic.Globals.scale>/shapefile/geo.shp``.

    Returns:
        (maps, shapes): ``maps`` is the dict from ``makeBlankMap``; ``shapes``
        maps each record's ``'FID'`` attribute to a ``ShapelyFeature`` in
        PlateCarree coordinates.
    """
    maps = makeBlankMap()
    configureLegend(maps['48'], data, colors=colors)
    reader = shpreader.Reader("assets/" + logic.Globals.scale + "/shapefile/geo.shp")

    # A single pass over the records keeps each FID paired with its own
    # geometry. Zipping records() against a separate geometries() pass can
    # misalign when the reader skips null shapes, and it reads the file twice.
    # Records without a drawable geometry are skipped. The geometry is wrapped
    # in a list because ShapelyFeature expects a collection of geometries.
    shapes = {
        record.attributes['FID']: ShapelyFeature([record.geometry], ccrs.PlateCarree())
        for record in reader.records()
        if record.geometry
    }

    return maps, shapes

def drawOneUnit(code, shape, placement, maps):
    """Add one unit's shape to the map it belongs on, filled by placement.

    A code equal to 'AK' or starting with "02" goes on the Alaska inset; 'HI'
    or "15" goes on the Hawaii inset; anything else goes on the '48' map. The
    fill color is ``colors[placement]`` with no outline.
    """
    target = '48'
    for inset, prefix in (('AK', "02"), ('HI', "15")):
        if code == inset or code.startswith(prefix):
            target = inset
            break

    maps[target].add_feature(shape, color=colors[placement], linewidth=0)

def drawUnits(data, maps, shapes):
    """Draw every shape in ``shapes`` onto ``maps`` using ``data``'s placements.

    Codes missing from ``data.getPlacements()`` are drawn with placement 0.
    """
    placement_by_code = data.getPlacements()
    for unit_code in shapes:
        unit_placement = placement_by_code.get(unit_code, 0)
        drawOneUnit(unit_code, shapes[unit_code], unit_placement, maps)

# Solve once, then draw that same solution. The legend (built from the data
# passed to initializeMap) then matches the colors drawn by drawUnits, and the
# cell no longer depends on a global `state` left over from the previous cell.
solution = logic.solve(numGroup=2, scale=1, metricID=0)
drawUnits(solution, *initializeMap(solution))