# In [None]:
# Imports
import pandas as pd
import plotly.express as px
from dash import Dash, html, dcc, Input, Output
import numpy as np
import dash_bootstrap_components as dbc
import dash_daq as daq

# Helper function to read data
def read_traffic_data(path):
    """Load a simulation-output CSV and normalise its column types.

    Parameters
    ----------
    path : str, path-like or file-like
        Anything accepted by ``pandas.read_csv``; the first CSV column is
        used as the index.

    Returns
    -------
    pandas.DataFrame
        The loaded table, with:

        * ``vehicle_type`` -- a categorical copy of ``mobility_mode``
          (one entry per row);
        * ``timestep`` -- the raw value parsed with the ``%M`` (minute)
          format and reduced to its integer minute.

    Raises
    ------
    KeyError
        If the table has no ``mobility_mode`` or ``timestep`` column.
    """
    # Pass `path` straight through so buffers and Path objects also work.
    data = pd.read_csv(path, index_col=0)
    data = data.infer_objects()
    missing = {"mobility_mode", "timestep"} - set(data.columns)
    if missing:
        # Fail with the offending file and column names instead of a bare
        # KeyError further down.
        raise KeyError(f"{path!s} is missing required column(s): {sorted(missing)}")
    # Categorise every row's mode. Passing only the *unique* modes gives an
    # array shorter than the frame and the assignment raises ValueError.
    data["vehicle_type"] = pd.Categorical(data["mobility_mode"])
    data["timestep"] = pd.to_datetime(data["timestep"], format="%M").dt.minute
    return data

# Datasets and column types.
# NOTE(review): both scenarios are read from the same CSV, so the
# "Optimized situation" view currently shows the current data -- point the
# optimized dataset at its own file once it exists.
_SIMULATION_CSV = "kamppi/simulation_output/clean_data.csv"
current_traffic_data = read_traffic_data(_SIMULATION_CSV)
optimized_traffic_data = read_traffic_data(_SIMULATION_CSV)
# livability_data = pd.read_csv("...")

# Global variables derived from the current scenario
_modes = current_traffic_data["mobility_mode"].unique()
mobility_modes = sorted(_modes)
timeline = current_traffic_data["timestep"].unique()
objectives = ["Summary", "Air quality", "Livability", "Traffic"]

# Traffic variables
traffic_variables = ["Traffic flow",
                     "Average speed",
                     "Traffic jams",
                     "Mobility modes",
                     "Noise"]

# Display name -> data column. Keys match traffic_variables / traffic_units
# (the flow entry used to be keyed "Traffic density", matching neither).
traffic_cols = {"Traffic flow": "amount",
                "Average speed": "vehicle_speed",
                "Traffic jams": "jam_prob",
                "Mobility modes": "mobility_mode",
                "Noise": "vehicle_noise"}

# Display name -> unit label
traffic_units = {"Traffic flow": "passengers",
                 "Average speed": "km/h",
                 "Traffic jams": "%",
                 "Mobility modes": "passengers",
                 "Noise": "dB"}

# Colour-scale upper bounds: 70 % of each column's maximum, keyed by column.
# Only numeric columns can be scaled: mobility_mode holds labels, and
# multiplying its (string) maximum by 0.7 raised TypeError.
traffic_maxes = (
    0.7 * current_traffic_data[list(traffic_cols.values())]
    .select_dtypes(include="number")
    .max(axis=0)
).to_dict()

# AQ variables
pm25_thresholds = [0, 10, 25, 50, 75, np.inf]
bc_thresholds = [0, 1, 3, 7, 12, np.inf]
ntot_thresholds = [0, 15, 30, 60, 100, np.inf]
aqi_thresholds = {0: "good", 1: "satisfactory", 2: "fair", 3: "poor", 4: "very poor"}

# Display name -> data column
aq_cols = {"Air quality": "aq_index",
           "Carbon monoxide": "vehicle_CO",
           "Carbon dioxide": "vehicle_CO2",
           "Hydrocarbon": "vehicle_HC",
           "Nitrogen oxides": "vehicle_NOx",
           "Particle matters": "vehicle_PMx"}

# Display name -> unit shown in axis / colour-bar titles. "Air quality" had
# no entry, so choosing it raised KeyError while titling the plots; "index"
# is a placeholder label -- adjust if the index has a defined unit.
aq_units = {"Noise": "dB",
            "Air quality": "index",
            "Carbon monoxide": "mg",
            "Carbon dioxide": "mg",
            "Hydrocarbon": "mg",
            "Nitrogen oxides": "mg",
            "Particle matters": "mg"}

# Emission dropdown options in display order. Only names that have a data
# column in aq_cols are offered: "Black carbon" has none yet and selecting
# it made every air-quality callback raise KeyError. It reappears
# automatically once aq_cols maps it to a column.
aq_variables = [name for name in ("Air quality",
                                  "Black carbon",
                                  "Carbon monoxide",
                                  "Carbon dioxide",
                                  "Hydrocarbon",
                                  "Nitrogen oxides",
                                  "Particle matters")
                if name in aq_cols]

# Colour-scale upper bounds: 70 % of each column's maximum in the current data
aq_maxes = (0.7 * current_traffic_data[list(aq_cols.values())].max(axis=0)).to_dict()

# Livability variables

# Helper function to normalize an array to a range (x, y)
def normalize_range(array, x, y):
    """Linearly rescale *array* so its minimum maps to *x* and its maximum to *y*.

    Parameters
    ----------
    array : array-like of numbers
        Values to rescale (list, NumPy array or pandas Series). NaN entries
        are ignored when finding the extremes and stay NaN in the output.
    x, y : float
        Target values for the minimum and the maximum respectively.

    Returns
    -------
    list of float
        Rescaled values in input order. An empty (or all-NaN) input is
        returned unchanged as a list. When every value is equal there is no
        spread to map, so all values become the midpoint ``(x + y) / 2``
        instead of NaN from a division by zero.
    """
    values = np.asarray(array, dtype=float)
    if values.size == 0 or np.isnan(values).all():
        return values.tolist()
    low = np.nanmin(values)
    span = np.nanmax(values) - low
    if span == 0:
        return np.where(np.isnan(values), np.nan, (x + y) / 2).tolist()
    return ((values - low) / span * (y - x) + x).tolist()

def filter_dataframe(data=current_traffic_data,
                     timestep_range=1,
                     spatial_scope=None,
                     emission="Carbon dioxide",
                     timeline_type="Average"):
    """Aggregate per-lane traffic and one emission column into time bins.

    Parameters
    ----------
    data : pandas.DataFrame
        Traffic table with ``vehicle_lane``, ``timestep``, ``lon``, ``lat``,
        ``amount`` and the emission column.
    timestep_range : int
        Bin width in timestep units. ``None``, non-numeric values and
        values below 1 fall back to 1.
    spatial_scope : optional
        If given, keep only rows whose ``vehicle_lane`` equals this value.
    emission : str
        A display name from ``aq_cols`` (e.g. ``"Carbon dioxide"``) or the
        data column itself (e.g. ``"vehicle_CO2"``).
    timeline_type : str
        ``"Average"`` aggregates with the mean, anything else with the sum.

    Returns
    -------
    pandas.DataFrame
        One row per (``vehicle_lane``, time bin), with columns
        ``vehicle_lane``, ``timestep`` (bin label = bin start + 1), the
        emission column, ``amount``, and ``lon``/``lat`` (first value in the
        group). Unobserved bins of the categorical ``timestep`` are kept.
    """
    # Accept both the display name and the raw column name.
    emission_col = aq_cols.get(emission, emission)
    required_traffic_cols = ["vehicle_lane", "timestep", "lon", "lat", "amount", emission_col]

    try:
        step = max(int(timestep_range), 1)
    except (TypeError, ValueError):
        step = 1

    if data.empty:
        start = stop = 0
    else:
        start, stop = int(data["timestep"].min()), int(data["timestep"].max())
    # Edges start, start+step, ... up to the first edge beyond `stop`, so the
    # last timestep lands in a bin, and exactly one label per bin (pd.cut
    # requires len(labels) == len(bins) - 1).
    n_bins = (stop - start) // step + 1
    edges = start + step * np.arange(n_bins + 1)
    labels = edges[:-1] + 1

    network = data[required_traffic_cols]
    if spatial_scope is not None:
        network = network[network["vehicle_lane"] == spatial_scope]
    network = network.copy()
    # Bin this frame's own timesteps (not another dataset's column).
    network["timestep"] = pd.cut(network["timestep"],
                                 bins=edges,
                                 labels=labels,
                                 right=False)

    func = "mean" if timeline_type == "Average" else "sum"
    return (
        network.groupby(["vehicle_lane", "timestep"], observed=False)
        .agg({emission_col: func, "amount": func, "lon": "first", "lat": "first"})
        .reset_index()
    )

# Center point for heatmap: mean longitude/latitude over all current rows
center_point = {
    "lon": current_traffic_data["lon"].mean(),
    "lat": current_traffic_data["lat"].mean(),
}

# Style: Bootstrap theme plus Bootstrap icons
external_stylesheets = [dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP]

# Hover-label layout for the drill-down plots
hovers = {"bgcolor": "white", "font_size": 16}

# Main component. suppress_callback_exceptions is presumably set because the
# air-quality graphs only exist after a tab has been rendered.
app = Dash(
    __name__,
    external_stylesheets=external_stylesheets,
    suppress_callback_exceptions=True,
)

app.layout = html.Div([

    # Visualization parameters: section title
    html.Div([
        html.H2("Visualization parameters", style={"display": "inline-block"}),
    ]),

    # Toggles and texts for parameters
    dbc.Row([

        # Aggregation method (Sum / Average)
        dbc.Col(
            html.Div([
                html.Label("Aggregation method", style={"padding-bottom": "2vh"}),
                dcc.RadioItems(options=["Sum", "Average"],
                               value="Sum",
                               id="crossfilter-timeline-type",
                               inline=True,
                               labelStyle={"automargin": True, "padding-top": "1vh", "font-size": 18, "padding-right": "1vw"}),
            ])
        ),

        # Timestep interval (bin width)
        dbc.Col(
            html.Div([
                html.Label("Timestep interval", style={"padding-bottom": "2vh"}),
                html.Div([
                    # The interval is used as a bin width, so it must be at
                    # least 1; timeline.min() may be 0, which crashed binning.
                    daq.NumericInput(value=1,
                                     id="crossfilter-timestep-range",
                                     min=1,
                                     max=int(timeline.max()),
                                     style={"display": "inline-block", "font-size": 18, "padding-right": "1vw"}),
                    html.P("minute(s)",
                           style={"display": "inline-block", "font-size": 18}),
                ]),
            ])
        ),

    # Row end
    ], justify="center", style={"automargin": True, "padding-top": "2vh", "font-size": 24}),

    # Dividing line between the params and plots
    html.Hr(),

    # Title
    html.H2("Results", style={"padding-bottom": "2vh", "padding-top": "1vh"}),

    # Tabs; their content is filled in by render_view_content
    dcc.Tabs(id='tabs-1', value='air-quality', children=[
        dcc.Tab(label='Summary', value='summary'),
        dcc.Tab(label='Traffic', value='traffic'),
        dcc.Tab(label='Air quality', value='air-quality'),
        dcc.Tab(label='Livability', value='livability'),
    ]),
    html.Div(id='tabs-content-1', style={"padding-top": "1vh"}),

# Main component
], style={"width": "99vw", "automargin": True, "padding": "5vh 5vh 5vh 5vh"})

# Callbacks

# Tab update callback
@app.callback(
    Output('tabs-content-1', 'children'),
    Input('tabs-1', 'value')
)
def render_view_content(tab):
    """Build the body of the selected results tab.

    'air-quality' gets the situation toggle, emission dropdown, network
    heatmap and drill-down graphs; 'summary', 'traffic' and 'livability'
    get a placeholder label; any other value yields None.
    """
    under_work = {
        'summary': "Summary under work",
        'traffic': "Traffic under work",
        'livability': "Livability under work",
    }
    if tab != 'air-quality':
        message = under_work.get(tab)
        if message is None:
            return None
        return html.Label(message, style={"padding-top": "2vh", "font-size": 20})

    # Current or optimized situation
    situation_toggle = dcc.RadioItems(
        options=["Current situation", "Optimized situation"],
        value="Current situation",
        id="crossfilter-situation",
        inline=True,
        labelStyle={"automargin": True, "padding-top": "1vh", "font-size": 18, "padding-right": "1vw"},
    )

    # Emission type dropdown
    emission_picker = html.Div(
        [
            html.Label("Emission type",
                       style={"font-size": 24, "padding-bottom": "2vh", "padding-top": "1vh"}),
            dcc.Dropdown(options=aq_variables,
                         value="Carbon dioxide",
                         id="crossfilter-emission",
                         style={"font-size": 18, "width": "50%", "margin": "auto", "text-align": "center"}),
        ],
        style={"padding-top": "2vh", "padding-bottom": "2vh"},
    )

    # Network heatmap
    heatmap_panel = html.Div(
        [dcc.Loading(dcc.Graph(id="traffic-heatmap"), type="cube")],
        style={"padding-top": "3vh"},
    )

    # Drill-down: location title, reset button, bar plot, stream graph
    drilldown_panel = html.Div(
        [
            html.Label(html.Pre(id='location-text'),
                       style={"font-size": 24, "padding-bottom": "4vh", "display": "block", "padding-top": "2vh"}),
            html.Button("Reset location",
                        id="reset_button",
                        n_clicks=0,
                        style={"display": "block", "font-size": 18, "automargin": True, "padding": "1vh"}),
            dcc.Loading(dcc.Graph(id="traffic-bar"), type='cube'),
            dcc.Loading(dcc.Graph(id="traffic-stream"), type='cube'),
        ],
        style={"padding-top": "3vh"},
    )

    return html.Center([
        situation_toggle,
        html.Div([emission_picker, heatmap_panel, drilldown_panel]),
    ])

# ---------------------------------------------------------------------------
# Air-quality tab callbacks.
#
# Every output is written by exactly one registered callback (Dash rejects
# duplicate outputs). The update_* callbacks handle both the network-level
# view (no heatmap click) and the clicked-lane drill-down. The init_* and
# update_heatmap functions are plain builders that other code may still call.
# ---------------------------------------------------------------------------


def _select_data(situation):
    """Return the traffic table for the chosen radio-button situation."""
    if situation == "Current situation":
        return current_traffic_data
    return optimized_traffic_data


def _axis_title(emission_name):
    """Return "<name> (<unit>)", or just the name when aq_units has no unit."""
    unit = aq_units.get(emission_name)
    return f"{emission_name} ({unit})" if unit else emission_name


def _interval(timestep_range):
    """Coerce the interval input (None while the box is being edited) to an int >= 1."""
    try:
        return max(int(timestep_range), 1)
    except (TypeError, ValueError):
        return 1


def _aggregator(timeline_type):
    """Map the aggregation radio value to a pandas aggregation name."""
    return "mean" if timeline_type == "Average" else "sum"


def _clicked_lane(clickData):
    """Return the vehicle_lane of the clicked heatmap point, or None.

    The heatmap sends ``vehicle_lane`` as its first custom_data column, so
    it is the first entry of the clicked point's ``customdata``.
    """
    if not clickData or not clickData.get("points"):
        return None
    customdata = clickData["points"][0].get("customdata")
    if not customdata:
        return None
    return customdata[0]


def _per_vehicle(frame, emission_column):
    """Emission per vehicle rounded to 4 d.p.; NaN where the amount is 0."""
    return (frame[emission_column] / frame["amount"].where(frame["amount"] != 0)).round(4)


def _mode_summary(data, emission_column, func, lane=None):
    """Aggregate *data* per mobility mode, optionally restricted to one lane."""
    if lane is not None:
        data = data[data["vehicle_lane"] == lane]
    summary = (data.groupby("mobility_mode")
               .agg({emission_column: func, "amount": func})
               .reset_index())
    summary["average_vehicle"] = _per_vehicle(summary, emission_column)
    return summary


def _build_heatmap(situation, emission_name, timestep_range, timeline_type):
    """Build the animated per-lane density map for one emission."""
    emission_column = aq_cols[emission_name]
    # filter_dataframe expects the display name and looks the column up itself.
    network = filter_dataframe(data=_select_data(situation),
                               emission=emission_name,
                               timestep_range=_interval(timestep_range),
                               timeline_type=timeline_type)
    # Rough zoom from the network extent in km (~111 km per degree). Fall
    # back to 14 when the extent is 0 or NaN, where log gives inf/NaN.
    extent_km = max(network["lon"].max() - network["lon"].min(),
                    network["lat"].max() - network["lat"].min()) * 111
    zoom_level = float(np.round(14 - np.log(extent_km))) if extent_km > 0 else 14
    fig = px.density_map(network, lat="lat", lon="lon", z=emission_column,
                         # NOTE(review): px applies this radius list to every
                         # animation frame as-is; confirm it lines up with
                         # each frame's points.
                         radius=normalize_range(network["amount"], 10, 18),
                         center=center_point,
                         zoom=zoom_level,
                         map_style="open-street-map",
                         range_color=(0, aq_maxes[emission_column]),
                         # Carried in customdata so a click identifies the lane.
                         custom_data=["vehicle_lane"],
                         hover_data={"amount": True,
                                     "lon": False,
                                     "lat": False,
                                     emission_column: True},
                         animation_frame="timestep",
                         labels={emission_column: emission_name,
                                 "vehicle_type": "Vehicle type",
                                 "amount": "Vehicle amount",
                                 "timestep": "Timestep"},
                         height=750,
                         title="Network heatmap")
    # Restyle px's animation slider in place; replacing layout.sliders would
    # drop the frame steps px generated.
    for slider in fig.layout.sliders:
        slider.update(active=0, currentvalue={"prefix": "Time: "}, font={"size": 18})
    fig.update_layout(coloraxis_colorbar_title_text=_axis_title(emission_name),
                      uirevision="Default")
    return fig


# Traffic heatmap callback
@app.callback(
    Output("traffic-heatmap", "figure"),
    Input("crossfilter-situation", "value"),
    Input("crossfilter-emission", "value"),
    Input("crossfilter-timestep-range", "value"),
    Input("crossfilter-timeline-type", "value"))
def init_heatmap(situation, emission_name, timestep_range, timeline_type):
    """Render the network heatmap whenever a visualization parameter changes."""
    return _build_heatmap(situation, emission_name, timestep_range, timeline_type)


def update_heatmap(situation, emission_name, timestep_range, timeline_type):
    """Build the same figure as init_heatmap.

    Not registered as a callback: the layout has no "x-heatmap" graph.
    """
    return _build_heatmap(situation, emission_name, timestep_range, timeline_type)


def init_location_text():
    """Return the location label used when no lane is selected."""
    return "Location: Network"


# Location text callback
@app.callback(
    Output("location-text", "children"),
    Input("traffic-heatmap", "clickData"))
def update_location_text(clickData):
    """Describe the clicked heatmap point, or the whole network if none."""
    if not clickData or not clickData.get("points"):
        return init_location_text()
    point = clickData["points"][0]
    lane = _clicked_lane(clickData)
    lat, lon = point.get("lat"), point.get("lon")
    parts = ["Location:"]
    if lane is not None:
        parts.append(f"lane {lane}")
    if lat is not None and lon is not None:
        parts.append(f"({lat:.4f} °N, {lon:.4f} °E)")
    return " ".join(parts)


def _build_bar(situation, emission_name, timeline_type, lane=None):
    """Build the per-mobility-mode bar chart for the network or one lane."""
    emission_column = aq_cols[emission_name]
    func = _aggregator(timeline_type)
    hist = _mode_summary(_select_data(situation), emission_column, func, lane)
    # Pin the x-axis to the largest bar of either scenario for this scope,
    # so switching situation does not rescale the axis.
    x_max = max(_mode_summary(frame, emission_column, func, lane)[emission_column].max()
                for frame in (current_traffic_data, optimized_traffic_data))
    hist_plot = px.bar(hist,
                       x=emission_column,
                       y="mobility_mode",
                       color="average_vehicle",
                       hover_data=["amount"],
                       labels={"average_vehicle": "Average vehicle emission",
                               emission_column: emission_name,
                               "mobility_mode": "Mobility mode",
                               "amount": "Vehicle amount"},
                       title="Mobility mode average")
    hist_plot.update_layout(hoverlabel=hovers,
                            title_x=0.11, bargap=0.5,
                            xaxis_title=_axis_title(emission_name),
                            yaxis_title="Mobility mode")
    if np.isfinite(x_max) and x_max > 0:
        hist_plot.update_xaxes(range=[0, x_max])
    return hist_plot


def init_bar(situation, emission_name, timestep_range, timeline_type):
    """Network-level bar chart; per-mode totals do not depend on timestep_range."""
    return _build_bar(situation, emission_name, timeline_type)


# Traffic bar plot callback
@app.callback(
    Output("traffic-bar", "figure"),
    Input("traffic-heatmap", "clickData"),
    Input("crossfilter-situation", "value"),
    Input("crossfilter-emission", "value"),
    Input("crossfilter-timestep-range", "value"),
    Input("crossfilter-timeline-type", "value"))
def update_bar(clickData, situation, emission_name, timestep_range, timeline_type):
    """Bar chart for the clicked lane, or for the network when nothing is clicked."""
    return _build_bar(situation, emission_name, timeline_type, _clicked_lane(clickData))


def _build_stream(situation, emission_name, timestep_range, timeline_type, lane=None):
    """Build the per-mode emission time series, binned by timestep_range."""
    emission_column = aq_cols[emission_name]
    func = _aggregator(timeline_type)
    step = _interval(timestep_range)
    data = _select_data(situation)
    start = data["timestep"].min()
    if lane is not None:
        data = data[data["vehicle_lane"] == lane]
    # Bins of `step` timesteps from the first timestep, labelled bin start + 1.
    binned = data.assign(timestep=start + (data["timestep"] - start) // step * step + 1)
    stream = (binned.groupby(["mobility_mode", "timestep"])
              .agg({emission_column: func, "amount": func})
              .reset_index())
    stream["average_vehicle"] = _per_vehicle(stream, emission_column)
    fig = px.area(stream,
                  x="timestep",
                  y=emission_column,
                  color="mobility_mode",
                  color_discrete_sequence=px.colors.sequential.Plasma_r,
                  hover_data=["average_vehicle", "amount"],
                  labels={"average_vehicle": "Average vehicle emission",
                          emission_column: emission_name,
                          "mobility_mode": "Mobility mode",
                          "amount": "Amount",
                          "timestep": "Timestep"},
                  title="Mobility mode time series")
    fig.update_layout(hoverlabel=hovers,
                      title_x=0.11,
                      yaxis_title=_axis_title(emission_name))
    return fig


def init_stream(situation, emission_name, timestep_range, timeline_type):
    """Network-level time series."""
    return _build_stream(situation, emission_name, timestep_range, timeline_type)


# Traffic stream callback
@app.callback(
    Output("traffic-stream", "figure"),
    Input("traffic-heatmap", "clickData"),
    Input("crossfilter-situation", "value"),
    Input("crossfilter-emission", "value"),
    Input("crossfilter-timestep-range", "value"),
    Input("crossfilter-timeline-type", "value"))
def update_stream(clickData, situation, emission_name, timestep_range, timeline_type):
    """Time series for the clicked lane, or for the network when nothing is clicked."""
    return _build_stream(situation, emission_name, timestep_range, timeline_type,
                         _clicked_lane(clickData))


# Reset the heatmap selection
@app.callback(
    Output("traffic-heatmap", "clickData"),
    Input("reset_button", "n_clicks"),
    prevent_initial_call=True)
def reset_heatmap(reset):
    """Clear the clicked point so the text and drill-down plots show the network."""
    return None

if __name__ == "__main__":
    # Guarded so importing this module does not start the dev server; a
    # notebook kernel runs as __main__, so it still starts there.
    app.run(debug=True)

# Last recorded cell output: KeyError: 'mobility_mode'