In [None]:
!pip install dash
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import scipy.stats
import geopandas as gpd
%matplotlib inline
from IPython.display import Markdown
from functools import reduce
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from dash import Dash, dcc, html, Input, Output, State, ALL, MATCH
from dash.exceptions import PreventUpdate
# Show floats in DataFrame output with thousands separators and two decimals.
pd.set_option("display.float_format", "{:,.2f}".format)

In [None]:
# Flip to True to render the dashboard container and every Plotly figure with a
# dark theme. Replaces the former `if 0:` toggle with a named switch.
DARK_MODE = False

if DARK_MODE:
    # `style` is merged into the root html.Div of the Dash layout. Dash passes the
    # dict to React, which expects camelCased CSS property names.
    style = {
        "backgroundColor": "#1b1b1b",  # rgb(27, 27, 27)
        "color": "white",  # font
    }
    pio.templates.default = "plotly_dark"
else:
    style = {}
    pio.templates.default = "plotly"

In [None]:
from os.path import join

# repo = "./"
repo = "https://raw.githubusercontent.com/nolasemon/slovakia-education/main"
data_path = join(repo, "data")
!curl -O {join(repo, "preprocessing.py")}

import preprocessing

table_names = [
    "RV_O_010_L_OK_SK.CSV",
    "RV_O_040_L_OK_SK.CSV",
    "RV_O_047_L_OK_SK.CSV",
    "RV_O_067_L_OK_SK.CSV",
]
tables = [
    preprocessing.preprocess(pd.read_csv(join(data_path, table), sep=";"))
    for table in table_names
]
table_10, table_40, table_47, table_67 = tables

# https://bbrejova.github.io/viz/data/districts.json
districts_url = join(data_path, 'districts.json')
districts = gpd.read_file(districts_url)
# https://raw.githubusercontent.com/drakh/slovakia-gps-data/master/GeoJSON/epsg_4326/districts_epsg_4326.geojson
districts_geojson_url = join(data_path, "districts.geojson")
districts_geojson = gpd.read_file(districts_geojson_url, crs="EPSG:4326")

In [None]:
# Index both district frames by the shared "IDN3" key so that the column
# assignment below is matched up by index label (pandas aligns on the index).
districts_geojson_indexed = districts_geojson.set_index("IDN3")
districts_indexed = districts.set_index("IDN3")
# Take the geometry and the two area columns from the GeoJSON source, storing the
# area columns under the names "Area" and "AreaHA".
districts_indexed[["geometry", "Area", "AreaHA"]] = districts_geojson_indexed[
    ["geometry", "Shape_Area", "VYMERA_ha"]
]
# Keep a separate handle on this IDN3-indexed frame: `districts_indexed` is
# rebound in a later cell, while `geo_frame` is what plot_groups merges into.
geo_frame = districts_indexed

In [None]:
# Re-key the district lookup by LAU1 code so it can be matched against the tables.
districts_indexed = districts.set_index("LAU1_CODE")
for table in tables:
    # Tables without a district code column cannot be enriched; skip them.
    if "LAU1_CODE" not in table:
        continue
    # Temporarily index the table by district code so the assignment aligns on it,
    # attach the region columns, then turn LAU1_CODE back into a regular column.
    # Everything is done in place so the table_XX aliases see the new columns.
    table.set_index("LAU1_CODE", inplace=True)
    region_columns = districts_indexed[["NUTS3", "NUTS3_CODE", "NUTS2", "NUTS2_CODE"]]
    table[["region_name", "NUTS3_CODE", "ecoregion_name", "NUTS2_CODE"]] = region_columns
    table.reset_index(inplace=True)

In [None]:
# Convert every object-dtype column to a categorical (via the string dtype).
# The tables are modified in place so the table_XX aliases stay in sync.
for table in tables:
    object_columns = [name for name in table.columns if table[name].dtype == "object"]
    table[object_columns] = table[object_columns].astype("string")
    table[object_columns] = table[object_columns].astype("category")

In [None]:
def compute_groups(data, groupby, chosen_query="", filter_query=""):
    """
    Aggregate the "count" column per group and relate a selection to its population.

    Arguments
        data - DataFrame with a "count" column and the `groupby` column(s)
        groupby - column name (or list of names) to group by
        chosen_query - DataFrame.query expression picking the rows whose counts
            form "number"; an empty string selects every (filtered) row
        filter_query - DataFrame.query expression restricting the population
            before anything else happens; an empty string keeps all rows

    Returns
        DataFrame with the group key(s) as ordinary columns plus
            number - sum of "count" over the chosen rows of the group
            number_percent - `number` as a percentage of all chosen rows
            total - sum of "count" over all filtered rows of the group
            percent - `number` as a percentage of `total`
        Groups with no chosen rows do not appear in the result.
    """
    population = data.query(filter_query) if filter_query != "" else data
    selected = population.query(chosen_query) if chosen_query != "" else population

    selected_counts = selected.groupby(groupby, observed=True)["count"].sum()
    population_counts = population.groupby(groupby, observed=True)["count"].sum()

    result = selected_counts.rename("number").to_frame()
    result["number_percent"] = result["number"] / selected["count"].sum() * 100
    # Aligned on the group index, so only groups present in the selection remain.
    result["total"] = population_counts
    result["percent"] = result["number"] / result["total"] * 100
    return result.reset_index()


def plot_groups(data, groupby, value, title=""):
    """
    Draw aggregated groups as a bar chart, a treemap or a choropleth map.

    Arguments
        data - aggregated frame containing the `groupby` column together with the
            "number", "number_percent" and "percent" columns used below
        groupby - name of the grouping column; "NUTS2_CODE", "NUTS3_CODE" and
            "LAU1_CODE" produce a map, anything else a bar chart / treemap
        value - column to display; for non-map groupings it must be "percent"
            (horizontal bar chart) or "number" (treemap)
        title - figure title

    Returns
        The plotly figure.

    Raises
        ValueError - `value` is neither "percent" nor "number" for a non-map
            grouping (previously this crashed on a None figure)
    """
    if groupby in ("NUTS2_CODE", "NUTS3_CODE", "LAU1_CODE"):
        # Join the aggregated values onto the module-level district geometries.
        merged = geo_frame.merge(data, on=groupby)
        figure = px.choropleth_mapbox(
            merged,
            geojson=merged.geometry,
            locations=merged.index,
            color=value,
            mapbox_style="carto-positron",
            center={"lat": 48.6737532, "lon": 19.696058},
            zoom=7,
            opacity=0.5,
            hover_data=["LAU1", "number", "percent"],
        )
    else:
        # Validate before sorting so the caller gets one clear error instead of a
        # KeyError from sort_values or an AttributeError on a missing figure.
        if value not in ("percent", "number"):
            raise ValueError(
                f"Unsupported display value {value!r} for grouping {groupby!r}; "
                "expected 'percent' or 'number'"
            )
        data = data.sort_values(by=value, ascending=False)
        if value == "percent":
            figure = px.bar(
                data,
                y=groupby,
                color=groupby,
                x="percent",
                orientation="h",
                hover_data=["number", "percent"],
            )
        else:
            figure = px.treemap(
                data,
                path=[px.Constant("all"), groupby],
                values="number",
                hover_data=["number", "number_percent"],
            )
    figure.update_layout(title=title)
    return figure


# Example: per region, the share of people whose education level is
# "higher education - 1st degree (Bc.)", drawn as a horizontal bar chart.
data = compute_groups(
    table_40,
    "region_name",
    chosen_query="`education` == 'vysokoškolské vzdelanie - 1. stupeň (Bc.)'",
)
fig = plot_groups(data, "region_name", "percent")
fig.show()

In [None]:
def _category_selector(data, attr, type):
    """Build a multi-select dropdown offering every category of ``data[attr]``.

    ``data[attr]`` must have the ``category`` dtype. The component id carries both
    the selector `type` and `attr` so callbacks can find it by pattern matching.
    """
    return dcc.Dropdown(
        id={"type": type, "attr": attr},
        options=data[attr].cat.categories,
        persistence=True,
        multi=True,
    )


def _int_range_selector(data, attr, type):
    """Build a range slider spanning ``[min, max + 1]`` of the integer column.

    The upper bound is one past the maximum because the generated query treats
    the range as half-open (``lower <= attr < upper``).
    """
    lower = int(data[attr].min())
    upper = int(data[attr].max()) + 1
    # Aim for roughly ten marks. The step must be at least 1: for columns whose
    # range is below 10 the integer division yields 0 and range() would raise.
    mark_step = max(1, (upper - 1 - lower) // 10)
    return dcc.RangeSlider(
        id={"type": type, "attr": attr},
        min=lower,
        max=upper,
        step=1,
        marks={tick: str(tick) for tick in range(lower, upper, mark_step)},
        value=[lower, upper],
        persistence=True,
    )


# Maps a lower-cased dtype name to a factory ``(data, attr, type) -> component``.
# The factories are only called when a selector is actually requested.
ATTR_SELECTOR_MAP = {
    "category": _category_selector,
    "int64": _int_range_selector,
}

def _category_membership_expr(attr, entry):
    """Query fragment keeping rows whose `attr` is one of the dropdown values.

    The value of a Dropdown is a list of the selected option values.
    """
    return f"`{attr}`.isin({entry})"


def _int_range_expr(attr, entry):
    """Query fragment keeping rows with `attr` in the half-open slider range.

    The value of a RangeSlider is a list holding the two boundaries.
    """
    lower, upper = entry[0], entry[1]
    return f"{lower} <= `{attr}` < {upper}"


# Maps a lower-cased dtype name to a builder ``(attr, entry) -> query fragment``.
ATTR_QUERY_EXPR_MAP = {
    "category": _category_membership_expr,
    "int64": _int_range_expr,
}

def _has_selected_options(entry):
    """A dropdown value contributes to the query once something is selected."""
    return len(entry) > 0


def _has_both_bounds(entry):
    """A range-slider value contributes only when both boundaries are present."""
    return len(entry) == 2


# Maps a lower-cased dtype name to a predicate deciding whether a selector's
# current value should be turned into a query fragment.
TEST_ATTR_FILTER_MAP = {
    "category": _has_selected_options,
    "int64": _has_both_bounds,
}


def get_selectivity(data, attributes, type):
    """
    Build one titled selector widget per attribute.

    Arguments
        data - table the attributes belong to
        attributes - iterable of column names to create selectors for
        type - selector group ("chosen" / "filter"), stored in each component id

    Returns
        List of html.Div elements, each holding a heading and the selector that
        ATTR_SELECTOR_MAP provides for the column's dtype.

    Raises
        TypeError - a column's dtype has no entry in ATTR_SELECTOR_MAP. The
            exception type is unchanged; the message now names the column and
            dtype instead of "'NoneType' object is not callable".
    """
    selector_blocks = []
    for attr in attributes:
        dtype_name = str(data[attr].dtype).lower()
        # Choose the appropriate selector factory according to the column dtype.
        make_selector = ATTR_SELECTOR_MAP.get(dtype_name)
        if make_selector is None:
            raise TypeError(
                f"No selector available for attribute {attr!r} of dtype "
                f"{dtype_name!r}; supported dtypes: {sorted(ATTR_SELECTOR_MAP)}"
            )
        selector_blocks.append(
            html.Div(
                [
                    html.H4(f"Select {data[attr].name}"),
                    make_selector(data, attr, type),
                ]
            )
        )
    return selector_blocks


def form_query(data, selectors, attrs):
    """
    Combine the current selector values into one DataFrame.query expression.

    Arguments
        data - table the attributes belong to (used to look up column dtypes)
        selectors - current values of the selector widgets
        attrs - column names, paired position by position with `selectors`

    Returns
        The individual fragments joined with " and "; an empty string when no
        selector contributes. Values that are None, or that the dtype's
        TEST_ATTR_FILTER_MAP predicate rejects, are skipped.
    """
    fragments = []
    for attr, entry in zip(attrs, selectors):
        if entry is None:
            continue
        dtype_name = str(data[attr].dtype).lower()
        # An unmatched dtype yields None here and fails when it is called.
        is_active = TEST_ATTR_FILTER_MAP.get(dtype_name, None)
        if not is_active(entry):
            continue
        build_fragment = ATTR_QUERY_EXPR_MAP.get(dtype_name, None)
        fragments.append(build_fragment(attr, entry))
    return " and ".join(fragments)


def big_annotation(text: str, color: str):
    """
    Describe a large, faint, tilted watermark centred on a figure.

    Arguments
        text - watermark text; it is upper-cased
        color - font colour of the watermark

    Returns
        A dict of annotation properties (paper-relative coordinates, no arrow).
    """
    watermark_font = {"color": color, "size": 100}
    return {
        "name": "draft watermark",
        "text": text.upper(),
        "textangle": -30,
        "opacity": 0.1,
        "font": watermark_font,
        # Paper coordinates: (0.5, 0.5) is the centre of the plotting area.
        "xref": "paper",
        "yref": "paper",
        "x": 0.5,
        "y": 0.5,
        "showarrow": False,
    }

In [None]:
app = Dash(__name__)

# Placeholder figure shown until the first successful update_figure callback.
figure = go.Figure()
figure.add_annotation(big_annotation("START", "black"))

# Left column: every control that drives the plot. The "choose-zone" and
# "filter-zone" containers are filled dynamically with per-attribute selectors.
controls_panel = html.Div(
    [
        html.H4("Enter title"),
        dcc.Input(id="title", type="text", persistence=True, style={"width": "100%"}),
        html.H4("Select table"),
        dcc.Dropdown(
            id="table-index",
            # The option value is the position of the table in `tables`.
            options=[{"value": v, "label": l} for v, l in enumerate(table_names)],
            persistence=True,
        ),
        html.H4("Select groupby"),
        dcc.Dropdown(
            id="groupby",
            persistence=True,
        ),
        html.H4("Select display value"),
        dcc.Dropdown(
            ["number", "percent"],
            "number",
            id="display-value",
            persistence=True,
        ),
        html.H4("Select chosen/percented attributes"),
        dcc.Dropdown(
            id="chosen-attributes",
            multi=True,
            persistence=True,
        ),
        html.H4("Select filter attributes"),
        dcc.Dropdown(
            id="filter-attributes",
            multi=True,
            persistence=True,
        ),
        html.H4("Choose zone"),
        html.Div(id="choose-zone"),
        html.H4("Filter zone"),
        html.Div(id="filter-zone"),
    ],
    style={"flex": 1, "minWidth": 400, "padding": 10},
)

# Right column: the plot and the code snippet that reproduces it.
output_panel = html.Div(
    [
        dcc.Graph(
            id="line-plot",
            # Style keys are camelCased, as Dash/React expects.
            style={"aspectRatio": "1.6"},
            figure=figure,
        ),
        # `disabled` is a component property, not a CSS rule; inside `style` it
        # had no effect and the box stayed editable.
        dcc.Textarea(id="function-call", style={"width": "100%"}, disabled=True),
    ],
    style={"flex": 2, "padding": 10},
)

app.layout = html.Div(
    [controls_panel, output_panel],
    style=style | {"padding": 10, "display": "flex", "flexDirection": "row"},
)


@app.callback(
    Output("chosen-attributes", "options"),
    Output("filter-attributes", "options"),
    Output("groupby", "options"),
    Input("table-index", "value"),
)
def update_attributes(table_index):
    """Offer the columns of the selected table in all three attribute dropdowns.

    Each option maps the column name (value) to a "<dtype>: <column>" label.
    Nothing is updated while no table is selected.
    """
    if table_index is None:
        raise PreventUpdate
    table = tables[table_index]
    column_options = {
        column: f"{table[column].dtype}: {column}" for column in table.columns
    }
    # The same options feed the chosen, filter and groupby dropdowns.
    return [column_options, column_options, column_options]


# Rebuild the "choose" and "filter" selector zones when the attribute lists change
@app.callback(
    Output("choose-zone", "children"),
    Output("filter-zone", "children"),
    Input("chosen-attributes", "value"),
    Input("filter-attributes", "value"),
    State("table-index", "value"),
)
def update_fg(chosen_attributes, filter_attributes, table_index):
    """Create one selector widget per chosen / filter attribute.

    A multi-select dropdown reports None until the user touches it. None is
    treated as "no attributes" so each zone updates independently; previously
    nothing appeared until both dropdowns had a value.
    """
    if table_index is None:
        raise PreventUpdate
    table = tables[table_index]
    return [
        get_selectivity(table, chosen_attributes or [], type="chosen"),
        get_selectivity(table, filter_attributes or [], type="filter"),
    ]


@app.callback(
    Output("line-plot", "figure"),
    Output("function-call", "value"),
    Input("title", "value"),
    Input("display-value", "value"),
    Input("groupby", "value"),
    Input({"type": "chosen", "attr": ALL}, "value"),
    State({"type": "chosen", "attr": ALL}, "id"),
    Input({"type": "filter", "attr": ALL}, "value"),
    State({"type": "filter", "attr": ALL}, "id"),
    State("table-index", "value"),
)
def update_figure(
    title,
    display_value,
    groupby,
    chosen,
    chosen_id,
    filter,
    filter_id,
    table_index,
):
    """Redraw the plot and show the function calls that reproduce it.

    The selector values (`chosen`, `filter`) are paired with the attribute names
    stored in the matching component ids and turned into query strings. If
    aggregation or plotting fails, an empty figure with an "ERROR" watermark is
    returned and the exception is printed.
    """
    if any(required is None for required in (display_value, groupby, table_index)):
        raise PreventUpdate

    table = tables[table_index]
    chosen_attrs = [component_id["attr"] for component_id in chosen_id]
    chosen_query = form_query(table, chosen, chosen_attrs)
    filter_attrs = [component_id["attr"] for component_id in filter_id]
    filter_query = form_query(table, filter, filter_attrs)

    try:
        grouped = compute_groups(table, groupby, chosen_query, filter_query)
        figure = plot_groups(grouped, groupby, display_value, title)
    except Exception as e:
        # Best effort: keep the dashboard responsive and flag the failure visually.
        figure = go.Figure()
        figure.add_annotation(big_annotation("ERROR", "red"))
        print(e)

    call_text = (
        f"data = compute_groups({table_names[table_index]!r}, {groupby=!r}, {chosen_query=!r}, {filter_query=!r})\n"
        f"figure = plot_groups(data, {groupby=!r}, value={display_value!r}, {title=!r})"
    )
    return [figure, call_text]


# `app.run_server` is obsolete in current Dash releases, and the unpinned
# `pip install dash` above pulls the newest one; `app.run` is its replacement
# and takes the same arguments.
app.run(port=8052, debug=True, use_reloader=True)