# In [13]:
# -*- coding: utf-8 -*-
"""
Interaktive Deutschlandkarte mit Seitenmenü (Dash)

- links: Checkbox-Menü (Ebenen, Baustoffklasse, Baujahr, Straßenart, Flächenanzeige)
- rechts: Plotly-Karte
- Brücken werden nur angezeigt, wenn sie ALLE gewählten Filter erfüllen
  (Baustoffklasse UND Baujahr-Range UND Straßenart).
- Choropleth-Fläche kann wahlweise Ø Zustandsnote, Ø Traglastindex, Ø Alter oder Ø Substanzkennzahl anzeigen.
- Skala (Colorbar + Farbnormierung) wird aus MIN/MAX der *gefilterten Brückenpunkte* bestimmt
  (Fallback: MIN/MAX aus Kreis-Mittelwerten, wenn keine Punkte übrig sind)
- NEU (Variante A): Brückenpunkte farbkodiert nach Zustandsnote (grün gut -> rot schlecht)

Starten:
    python bruecken_app.py
Browser:
    http://127.0.0.1:8063
"""

import re
import json
import unicodedata
from pathlib import Path

import pandas as pd
import geopandas as gpd
from shapely.geometry import Point

import plotly.graph_objects as go

import dash
from dash import dcc, html
from dash.dependencies import Input, Output


# --------------------------------------------------------------------
# 0) Paths
# --------------------------------------------------------------------
# Bridge statistics table (semicolon-separated CSV)
csv_path = Path("data/original_bridge_statistic_germany.csv")
# Alternative input file, kept for quick switching:
# csv_path = Path("data/filled_bridge_statistic_germany.csv")
# District (Landkreis) and federal-state polygons
geo_path = Path("landkreise_simplify200.geojson")
states_path = Path("bundeslaender_simplify0.geojson")

# Reference year for the age calculation
AGE_REF_YEAR = 2025


# --------------------------------------------------------------------
# Helper
# --------------------------------------------------------------------
def pick_first(cols, candidates):
    """Return the first entry of ``candidates`` that is contained in ``cols``.

    Candidates are tried in the given order; ``None`` is returned when none
    of them is present.
    """
    return next((name for name in candidates if name in cols), None)


def norm_name(s: str) -> str:
    """Normalize a district name into a key for name-based matching.

    The value is converted to ``str``, composed to NFC, trimmed and
    lowercased; German umlauts and "ß" are transliterated (ae/oe/ue/ss),
    any remaining non-ASCII characters are dropped, the administrative
    words "landkreis", "kreisfreie stadt", "kreis", "stadt", "region" and
    "lkr"/"lkr." are removed, and whitespace is collapsed.

    Returns:
        The normalized name (may be empty).
    """
    # Compose first so that decomposed umlauts ("u" + U+0308) are caught by
    # the transliteration below instead of being reduced to the bare vowel.
    s = unicodedata.normalize("NFC", str(s)).strip().lower()
    s = s.replace("ä", "ae").replace("ö", "oe").replace("ü", "ue").replace("ß", "ss")
    s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
    # The optional dot of "lkr." has to come after the word boundary: there
    # is no boundary between "." and a following space, so a pattern with
    # the dot in front of the boundary would leave a stray "." behind.
    s = re.sub(r"\b(?:landkreis|kreisfreie stadt|kreis|stadt|region)\b|\blkr\b\.?", "", s)
    s = re.sub(r"\s+", " ", s)
    return s.strip()


def make_points_from_lonlat(dfin, lon_col="x2", lat_col="y2"):
    """Create a point GeoDataFrame (EPSG:4326) from longitude/latitude columns.

    Coordinate values are read as text, a decimal comma is turned into a
    point, and rows where either coordinate cannot be parsed are dropped.
    The remaining rows are copied and get a point geometry.
    """
    def _parse(column):
        as_text = dfin[column].astype(str).str.replace(",", ".", regex=False)
        return pd.to_numeric(as_text, errors="coerce")

    lon_num = _parse(lon_col)
    lat_num = _parse(lat_col)
    usable = lon_num.notna() & lat_num.notna()

    points = list(map(Point, zip(lon_num[usable], lat_num[usable])))
    return gpd.GeoDataFrame(dfin[usable].copy(), geometry=points, crs=4326)


def make_points_from_utm(dfin, x_col="X", y_col="Y", epsg=25832):
    """Create a point GeoDataFrame from projected x/y columns, reprojected to EPSG:4326.

    Coordinate values are read as text with an optional decimal comma; rows
    where either value cannot be parsed are dropped. The points are created
    in the CRS given by ``epsg`` and then transformed to EPSG:4326.
    """
    def _parse(column):
        as_text = dfin[column].astype(str).str.replace(",", ".", regex=False)
        return pd.to_numeric(as_text, errors="coerce")

    x_num = _parse(x_col)
    y_num = _parse(y_col)
    usable = x_num.notna() & y_num.notna()

    points = list(map(Point, zip(x_num[usable], y_num[usable])))
    projected = gpd.GeoDataFrame(dfin[usable].copy(), geometry=points, crs=epsg)
    return projected.to_crs(4326)


def safe_minmax(series: pd.Series):
    """Return ``(min, max)`` of the numeric values in ``series``.

    Values that cannot be converted to numbers are ignored. Returns ``None``
    when nothing numeric remains. If minimum and maximum coincide, the range
    is widened symmetrically by a tiny amount so it never has zero width.
    """
    values = pd.to_numeric(series, errors="coerce").dropna()
    if values.empty:
        return None

    low, high = float(values.min()), float(values.max())
    if low != high:
        return low, high

    # Degenerate range: pad relative to the value, or absolutely around zero
    pad = abs(low) * 1e-6 if low != 0 else 1e-6
    return low - pad, high + pad


# Convenience wrapper: supplies a fallback range when no numeric value exists
def safe_minmax_with_fallback(series: pd.Series, fallback=(1.0, 4.0)):
    """Return ``safe_minmax(series)``, or ``fallback`` when that yields ``None``."""
    bounds = safe_minmax(series)
    return bounds if bounds is not None else fallback


# --------------------------------------------------------------------
# 1) Load districts (Landkreise) & federal states
# --------------------------------------------------------------------
gdf_kreise = gpd.read_file(geo_path)

# Detect the name / key columns from a list of known candidates
colsK = gdf_kreise.columns
col_name = pick_first(colsK, ["GEN", "NAME_3", "NAME", "GEN_NAME", "KREIS"])
col_ags = pick_first(colsK, ["AGS", "AGS_0", "RS", "RS_0", "ID_3"])

if col_name is None:
    raise ValueError(f"Konnte keinen Kreisnamen finden. Spalten: {list(colsK)}")

# Unify the CRS (WGS84); data without CRS information is labelled as EPSG:4326
if gdf_kreise.crs is None:
    gdf_kreise = gdf_kreise.set_crs(4326)
else:
    gdf_kreise = gdf_kreise.to_crs(4326)

# Federal states
gdf_states = gpd.read_file(states_path)
if gdf_states.crs is None:
    gdf_states = gdf_states.set_crs(4326)
else:
    gdf_states = gdf_states.to_crs(4326)

colsS = gdf_states.columns
state_name_col = pick_first(colsS, ["GEN", "NAME", "STATE_NAME"])
state_id_col = pick_first(colsS, ["RS", "RS_0", "AGS", "AGS_0", "ID"])
if state_id_col is None:
    # No key column: fall back to the state name as identifier
    state_id_col = state_name_col
if state_id_col is None:
    # Fail with a clear message instead of an opaque KeyError on a None column
    raise ValueError(f"Konnte keine Bundesland-ID/-Namen finden. Spalten: {list(colsS)}")
gdf_states["state_id"] = gdf_states[state_id_col].astype(str)


# --------------------------------------------------------------------
# 2) Load bridges & create points
# --------------------------------------------------------------------
df = pd.read_csv(csv_path, sep=";", decimal=",", dtype=str, keep_default_na=False)

# Make sure every column used later exists (missing ones become empty strings)
required_columns = (
    "Zustandsnote", "Substanzkennzahl",
    "x2", "y2", "X", "Y",
    "Kreis", "Baujahr Überbau", "Traglastindex",
    "Zugeordneter Sachverhalt vereinfacht", "Baustoffklasse", "Bauwerksname",
)
for column in required_columns:
    if column not in df.columns:
        df[column] = ""

# Condition grade, substance index and construction year as numbers
# (decimal comma -> point, unparsable values -> NaN)
for column in ("Zustandsnote", "Substanzkennzahl", "Baujahr Überbau"):
    df[column] = pd.to_numeric(
        df[column].astype(str).str.replace(",", ".", regex=False),
        errors="coerce",
    )

# Prefer WGS84 (x2/y2); use the UTM columns only if no WGS84 point exists
gdf_pts = make_points_from_lonlat(df, "x2", "y2")
used_coords = "WGS84 (x2/y2)"
if not len(gdf_pts):
    gdf_pts = make_points_from_utm(df, "X", "Y", epsg=25832)
    used_coords = "UTM32 (EPSG:25832)"

gdf_pts["lon"] = gdf_pts.geometry.x
gdf_pts["lat"] = gdf_pts.geometry.y

# Construction year on the points (used by the year filter)
gdf_pts["Baujahr"] = pd.to_numeric(gdf_pts["Baujahr Überbau"], errors="coerce")

# Spatial join of points to districts (only possible when points exist)
spatial_join_ok = len(gdf_pts) > 0
joined = None
if spatial_join_ok:
    kreis_cols = [col_name, "geometry"] if col_ags is None else [col_name, col_ags, "geometry"]
    joined = gpd.sjoin(
        gdf_pts,
        gdf_kreise[kreis_cols],
        how="left",
        predicate="within",
    )


# --------------------------------------------------------------------
# 3) District aggregation (for the choropleth) – incl. load index, age, substance index
# --------------------------------------------------------------------
kreise_names = gdf_kreise[[col_name]].copy()
kreise_names["__kreis_norm__"] = kreise_names[col_name].map(norm_name)

# Base frame: spatial join result if available, otherwise the raw CSV
base = joined if spatial_join_ok else df.copy()

# Name-based fallback: normalized "Kreis" text -> district name from the geometry file
base["__br_kreis_norm__"] = base["Kreis"].map(norm_name)
name_map = dict(zip(kreise_names["__kreis_norm__"], gdf_kreise[col_name]))
map_series = base["__br_kreis_norm__"].map(name_map)

if spatial_join_ok:
    if col_name in base.columns:
        # Keep spatially matched names, fill the gaps from the name mapping
        base[col_name] = base[col_name].fillna(map_series)
    else:
        base[col_name] = map_series
else:
    base[col_name] = map_series

# Load index (Roman numerals -> 1–5); anything else becomes NaN
roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
base["Traglastindex_norm"] = (
    base["Traglastindex"]
    .astype(str)
    .str.upper()
    .str.strip()
    .map(roman_map)
)

# Age of the superstructure (in years)
base["Alter_Überbau"] = AGE_REF_YEAR - base["Baujahr Überbau"]

# Aggregation per district name.
# NOTE(review): grouping is by name, so districts that share a name are
# pooled and receive identical statistics – confirm whether that is acceptable.
metric_cols = ["n_bridges", "mean_zust", "mean_trag", "mean_age", "mean_substanz", "median_zust"]
agg_stats = (
    base
    .dropna(subset=[col_name])
    .groupby(col_name, as_index=False)
    .agg(
        n_bridges=("Bauwerksname", "size"),
        mean_zust=("Zustandsnote", "mean"),
        mean_trag=("Traglastindex_norm", "mean"),
        mean_age=("Alter_Überbau", "mean"),
        mean_substanz=("Substanzkennzahl", "mean"),
        median_zust=("Zustandsnote", "median"),
    )
    .rename(columns={col_name: "NAME"})
)

# Attach the statistics to the geometries BEFORE any key column is added to
# them: the key column of gdf_kreise then keeps its plain name (no _x/_y
# suffixes) and same-named districts cannot multiply the merged rows.
gplot = gdf_kreise.merge(
    agg_stats[["NAME"] + metric_cols],
    left_on=col_name,
    right_on="NAME",
    how="left",
)

# Aggregation table with district keys (same columns as before)
if col_ags is not None:
    kreise_keys = gdf_kreise[[col_name, col_ags]].drop_duplicates()
    agg = agg_stats.merge(
        kreise_keys.rename(columns={col_name: "NAME"}),
        on="NAME",
        how="left",
    )
    agg = agg[[col_ags, "NAME"] + metric_cols]
else:
    agg = agg_stats[["NAME"] + metric_cols]

agg = agg.sort_values("NAME").reset_index(drop=True)

# Feature id: district key if available, otherwise the district name taken
# from the geometry file (the aggregated "NAME" is NaN for districts without
# bridges and would collapse all of them onto the id "nan").
id_col = col_ags
if id_col is not None:
    gplot["id"] = gplot[id_col].astype(str)
else:
    gplot["id"] = gplot[col_name].astype(str)

gplot_json = gplot[["id", "geometry"]].copy()
geojson_kreise = json.loads(gplot_json.to_json())

states_json = gdf_states[["state_id", "geometry"]].copy()
geojson_states = json.loads(states_json.to_json())

# Outline of Germany: all states dissolved into a single geometry
germany_outline = gdf_states.dissolve().reset_index(drop=True)
germany_outline["id"] = "germany_outline"
outline_json = json.loads(germany_outline[["id", "geometry"]].to_json())


# --------------------------------------------------------------------
# 4) Prepare bridge attributes (material, construction year, road type)
# --------------------------------------------------------------------
zust = pd.to_numeric(gdf_pts["Zustandsnote"], errors="coerce")
subs = pd.to_numeric(gdf_pts["Substanzkennzahl"], errors="coerce")

# Road type mapping (codes without an entry are shown unchanged)
mapping_strasse = {
    "B": "Bundesstraße",
    "O": "Bundesstraße",
    "A": "Autobahn",
}
strassen_roh = gdf_pts["Zugeordneter Sachverhalt vereinfacht"].astype(str)
gdf_pts["StraßenartLabel"] = strassen_roh.map(lambda x: mapping_strasse.get(x, x))

# Hover text
gdf_pts["hover"] = (
    "<b>" + gdf_pts["Bauwerksname"].astype(str) + "</b><br>"
    "Zustandsnote: " + zust.round(2).astype(str) + "<br>"
    "Substanzkennzahl: " + subs.round(2).astype(str) + "<br>"
    "Baustoffklasse: " + gdf_pts["Baustoffklasse"].astype(str) + "<br>"
    "Baujahr Überbau: " + gdf_pts["Baujahr Überbau"].astype(str) + "<br>"
    "Straßenart: " + gdf_pts["StraßenartLabel"].astype(str)
)

# Layer checkboxes
layer_options = [
    {"label": "Umriss",                        "value": "outline"},
    {"label": "Bundesländer",                  "value": "states"},
    {"label": "Zustand / Traglast / Alter / Substanz", "value": "choropleth"},
    {"label": "Landkreis-Grenzen",             "value": "kreise"},
]

# Material checkboxes. The CSV is read with keep_default_na=False, so missing
# values arrive as empty strings; skip them, otherwise the menu would contain
# a checkbox without a label.
baustoff_vals = sorted(
    b for b in gdf_pts["Baustoffklasse"].dropna().unique() if str(b).strip()
)
baustoff_options = [{"label": b, "value": b} for b in baustoff_vals]

# Road type checkboxes (blank values skipped for the same reason)
road_vals = sorted(
    r for r in gdf_pts["StraßenartLabel"].dropna().unique() if str(r).strip()
)
road_options = [{"label": r, "value": r} for r in road_vals]

# --- Construction year range ---
# NOTE(review): the lower bound is fixed, so bridges built before 1900 can
# never fall inside the slider range – confirm this is intended.
BAUJAHR_MIN = 1900
if gdf_pts["Baujahr"].notna().any():
    BAUJAHR_MAX_DATA = int(gdf_pts["Baujahr"].max())
else:
    BAUJAHR_MAX_DATA = AGE_REF_YEAR

# Round the upper bound up to the next 25-year step so the last mark is the slider end
rest = (BAUJAHR_MAX_DATA - BAUJAHR_MIN) % 25
BAUJAHR_MAX = BAUJAHR_MAX_DATA + (25 - rest) if rest != 0 else BAUJAHR_MAX_DATA
baujahr_marks = {year: str(year) for year in range(BAUJAHR_MIN, BAUJAHR_MAX + 1, 25)}

# Radio items for the area (choropleth) metric
choropleth_mode_options = [
    {"label": "Zustandsnote",        "value": "zustand"},
    {"label": "Traglastindex",       "value": "trag"},
    {"label": "Durchschnittsalter",  "value": "alter"},
    {"label": "Substanzkennzahl",    "value": "substanz"},
]


# --------------------------------------------------------------------
# 5) Helper: build the figure
# --------------------------------------------------------------------
def _filter_bridge_points(points, selected_baustoff, selected_year_range, selected_road):
    """Return the rows of ``points`` that satisfy all active filters (AND logic).

    An empty material / road selection means "no restriction". The year
    filter is applied only when ``selected_year_range`` is a list or tuple
    of two values; rows without a construction year are then excluded.
    """
    if selected_baustoff:
        points = points[points["Baustoffklasse"].isin(selected_baustoff)]

    if isinstance(selected_year_range, (list, tuple)) and len(selected_year_range) == 2:
        year_min, year_max = selected_year_range
        points = points[
            points["Baujahr"].notna()
            & (points["Baujahr"] >= year_min)
            & (points["Baujahr"] <= year_max)
        ]

    if selected_road:
        points = points[points["StraßenartLabel"].isin(selected_road)]

    return points


def build_figure(selected_layers, selected_baustoff, selected_year_range, selected_road, choropleth_mode):
    """Build the map figure for the current menu selection.

    Args:
        selected_layers: Collection of layer keys to draw ("outline",
            "states", "choropleth", "kreise").
        selected_baustoff: Selected material classes (empty = all).
        selected_year_range: ``[year_min, year_max]`` of the construction year.
        selected_road: Selected road type labels (empty = all).
        choropleth_mode: "zustand", "trag", "alter" or "substanz"; any other
            value is treated like "zustand".

    Returns:
        A ``plotly.graph_objects.Figure`` with the selected layers and the
        filtered bridge points.
    """
    fig = go.Figure()
    transparent_fill = [[0, "rgba(0,0,0,0)"], [1, "rgba(0,0,0,0)"]]

    # Bridges after filtering; used for the point layer and for the colour
    # range of the choropleth.
    pts = _filter_bridge_points(gdf_pts, selected_baustoff, selected_year_range, selected_road)
    points_shown = len(pts) > 0
    choropleth_bar_shown = False

    # Choropleth first (background fill): mean metric per district
    if "choropleth" in selected_layers:
        metric_specs = {
            "trag": ("mean_trag", "Ø Traglastindex", ".2f"),
            "alter": ("mean_age", "Ø Alter (Jahre)", ".1f"),
            "substanz": ("mean_substanz", "Ø Substanzkennzahl", ".2f"),
        }
        metric_col, metric_label, fmt = metric_specs.get(
            choropleth_mode, ("mean_zust", "Ø Zustandsnote", ".2f")
        )

        # Derive the metric for the colour range from the filtered points
        if choropleth_mode == "trag":
            roman_map_local = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
            pts_metric = (
                pts["Traglastindex"]
                .astype(str).str.upper().str.strip()
                .map(roman_map_local)
            )
        elif choropleth_mode == "alter":
            pts_metric = AGE_REF_YEAR - pd.to_numeric(pts["Baujahr Überbau"], errors="coerce")
        elif choropleth_mode == "substanz":
            pts_metric = pd.to_numeric(pts["Substanzkennzahl"], errors="coerce")
        else:
            pts_metric = pd.to_numeric(pts["Zustandsnote"], errors="coerce")

        rng = safe_minmax(pts_metric)

        # Fallback: no usable point values after filtering -> range from district means
        if rng is None:
            rng = safe_minmax(gplot[metric_col])

        vmin, vmax = rng if rng is not None else (None, None)

        gplot_valid = gplot[gplot[metric_col].notna()]
        gplot_missing = gplot[gplot[metric_col].isna()]

        # Districts without data – light grey
        if len(gplot_missing) > 0:
            fig.add_trace(
                go.Choropleth(
                    geojson=geojson_kreise,
                    locations=gplot_missing["id"],
                    z=[0] * len(gplot_missing),
                    featureidkey="properties.id",
                    name="keine Daten",
                    showlegend=True,
                    showscale=False,
                    colorscale=[[0, "lightgrey"], [1, "lightgrey"]],
                    hoverinfo="skip",
                    marker_line_color="rgba(0,0,0,0)",
                    marker_line_width=0.0,
                )
            )

        # Districts with data – coloured, with explicit zmin/zmax
        if len(gplot_valid) > 0:
            if vmin is not None and vmax is not None:
                bar_title = f"{metric_label}<br>min={vmin:{fmt}} / max={vmax:{fmt}}"
            else:
                bar_title = metric_label

            # Both colorbars share the same default position; when the point
            # colorbar is drawn as well, this one takes the upper half.
            bar_layout = dict(len=0.45, y=0.75) if points_shown else {}

            fig.add_trace(
                go.Choropleth(
                    geojson=geojson_kreise,
                    locations=gplot_valid["id"],
                    z=gplot_valid[metric_col],
                    featureidkey="properties.id",
                    name=metric_label + " (Landkreise)",
                    showlegend=True,
                    showscale=True,
                    colorbar=dict(title=bar_title, **bar_layout),
                    colorscale="RdYlGn_r",
                    zmin=vmin,
                    zmax=vmax,
                    marker_opacity=1.0,
                    hovertemplate=f"<b>%{{text}}</b><br>{metric_label}: %{{z:{fmt}}}<extra></extra>",
                    text=gplot_valid["NAME"],
                    marker_line_color="rgba(0,0,0,0)",
                    marker_line_width=0.0,
                )
            )
            choropleth_bar_shown = True

    # Outline of Germany (offered in the layer menu, drawn as a transparent
    # polygon with a border line)
    if "outline" in selected_layers:
        fig.add_trace(
            go.Choropleth(
                geojson=outline_json,
                locations=germany_outline["id"],
                z=[0] * len(germany_outline),
                featureidkey="properties.id",
                name="Umriss",
                showlegend=True,
                showscale=False,
                hoverinfo="skip",
                colorscale=transparent_fill,
                marker_line_color="rgba(0,0,0,1)",
                marker_line_width=1.6,
            )
        )

    # Federal states
    if "states" in selected_layers:
        fig.add_trace(
            go.Choropleth(
                geojson=geojson_states,
                locations=gdf_states["state_id"],
                z=[0] * len(gdf_states),
                featureidkey="properties.state_id",
                name="Bundesländer",
                showlegend=True,
                showscale=False,
                hoverinfo="skip",
                colorscale=transparent_fill,
                # alpha must lie in [0, 1]
                marker_line_color="rgba(70,70,70,0.95)",
                marker_line_width=1.3,
            )
        )

    # District borders
    if "kreise" in selected_layers:
        fig.add_trace(
            go.Choropleth(
                geojson=geojson_kreise,
                locations=gplot["id"],
                z=[0] * len(gplot),
                featureidkey="properties.id",
                name="Landkreis-Grenzen",
                showlegend=True,
                showscale=False,
                hoverinfo="skip",
                colorscale=transparent_fill,
                marker_line_color="rgba(170,170,170,0.9)",
                marker_line_width=0.4,
            )
        )

    # Bridge points, colour-coded by condition grade (green good -> red bad)
    if points_shown:
        pts_zust = pd.to_numeric(pts["Zustandsnote"], errors="coerce")
        zmin_pts, zmax_pts = safe_minmax_with_fallback(pts_zust, fallback=(1.0, 4.0))

        # Lower half when the choropleth colorbar occupies the upper half
        bar_layout = dict(len=0.45, y=0.25) if choropleth_bar_shown else {}

        fig.add_trace(
            go.Scattergeo(
                lon=pts["lon"],
                lat=pts["lat"],
                mode="markers",
                name="Brücken (gefiltert)",
                text=pts["hover"],
                hovertemplate="%{text}<extra></extra>",
                marker=dict(
                    size=3.0,              # reduce (e.g. 2.0) for very many points
                    opacity=0.85,
                    color=pts_zust,        # condition grade drives the colour
                    cmin=zmin_pts,
                    cmax=zmax_pts,
                    colorscale="RdYlGn_r",
                    showscale=True,
                    colorbar=dict(
                        title="Zustandsnote<br>(grün gut / rot schlecht)",
                        **bar_layout,
                    ),
                    line=dict(
                        color="rgba(0,0,0,1)",  # marker border colour
                        width=0.25,             # marker border width
                    ),
                ),
                showlegend=True,
            )
        )

    # Layout
    fig.update_geos(
        fitbounds="locations",
        visible=False,
        projection_type="mercator",
        center=dict(lat=51, lon=10),
        projection_scale=7,
    )

    fig.update_layout(
        margin={"r": 20, "t": 60, "l": 20, "b": 20},
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)"),
        title=(
            "Mittlere Kennzahlen pro Landkreis (Zustand / Traglast / Alter / Substanz)<br>"
            f"<sup>Zuordnung: {used_coords if spatial_join_ok else 'Name-Match'}; Skala: Min/Max aus gefilterten Brücken</sup>"
        ),
    )

    return fig


# --------------------------------------------------------------------
# 6) Dash app
# --------------------------------------------------------------------
app = dash.Dash(__name__)

# One entry per line in every checklist / radio group
_block_label = {"display": "block"}

# Left-hand menu: layers, area metric, material, construction year, road type
_menu_children = [
    html.H3("Ebenen"),
    dcc.Checklist(
        id="layer-checklist",
        options=layer_options,
        value=[option["value"] for option in layer_options],  # all layers on
        labelStyle=_block_label,
    ),
    html.Hr(),
    html.H3("Flächenanzeige"),
    dcc.RadioItems(
        id="choropleth-mode",
        options=choropleth_mode_options,
        value="zustand",  # default metric
        labelStyle=_block_label,
    ),
    html.Hr(),
    html.H3("Baustoffklasse"),
    dcc.Checklist(
        id="baustoff-checklist",
        options=baustoff_options,
        value=[],  # nothing selected = all
        labelStyle=_block_label,
    ),
    html.Hr(),
    html.H3("Baujahr Überbau"),
    dcc.RangeSlider(
        id="year-slider",
        min=BAUJAHR_MIN,
        max=BAUJAHR_MAX,
        step=1,
        value=[BAUJAHR_MIN, BAUJAHR_MAX],
        marks=baujahr_marks,
        allowCross=False,
        tooltip={"always_visible": False, "placement": "bottom"},
    ),
    html.Hr(),
    html.H3("Straßenart"),
    dcc.Checklist(
        id="road-checklist",
        options=road_options,
        value=[],
        labelStyle=_block_label,
    ),
]

_menu_panel = html.Div(
    style={
        "width": "22%",
        "padding": "10px",
        "borderRight": "1px solid #ccc",
        "overflowY": "auto",
    },
    children=_menu_children,
)

# Right-hand map
_map_panel = html.Div(
    style={"width": "78%", "padding": "10px"},
    children=[
        dcc.Graph(
            id="map-figure",
            style={"height": "100%"},
        )
    ],
)

app.layout = html.Div(
    style={"display": "flex", "height": "100vh", "fontFamily": "sans-serif"},
    children=[_menu_panel, _map_panel],
)


@app.callback(
    Output("map-figure", "figure"),
    [
        Input(component_id, "value")
        for component_id in (
            "layer-checklist",
            "choropleth-mode",
            "baustoff-checklist",
            "year-slider",
            "road-checklist",
        )
    ],
)
def update_map(selected_layers, choropleth_mode, selected_baustoff, selected_year_range, selected_road):
    """Rebuild the map figure whenever one of the menu controls changes.

    ``None`` inputs are replaced by defaults (empty selections, the full
    year range, the "zustand" metric) before ``build_figure`` is called.
    """
    def _or_default(value, default):
        # Only None is replaced; empty selections are passed through as they are
        return default if value is None else value

    return build_figure(
        _or_default(selected_layers, []),
        _or_default(selected_baustoff, []),
        _or_default(selected_year_range, [BAUJAHR_MIN, BAUJAHR_MAX]),
        _or_default(selected_road, []),
        _or_default(choropleth_mode, "zustand"),
    )


# Start the Dash server (debug mode, port 8063) when run as the main module
if __name__ == "__main__":
    app.run(debug=True, port=8063)


# In [4]:
# -*- coding: utf-8 -*-
"""
Interactive Germany map (Dash)

Left: filters/toggles
Right: Plotly map

- Bridge points are filtered with AND logic (material, year range, road type)
- County choropleth can show mean condition / load index / age / substance
- Choropleth color range is derived from MIN/MAX of the filtered bridge points (fallback: county means)
- New: toggle whether bridge points are shown at all
- New: toggle whether bridge points are colored by condition (else single color)
"""

import json
import re
import unicodedata
from pathlib import Path

import pandas as pd
import geopandas as gpd
from shapely.geometry import Point

import plotly.graph_objects as go

import dash
from dash import dcc, html
from dash.dependencies import Input, Output

# -----------------------------
# Paths / constants
# -----------------------------
# Bridge statistics table (semicolon-separated CSV)
CSV_PATH = Path("data/original_bridge_statistic_germany.csv")
# County and federal-state polygons
KREISE_GEO_PATH = Path("landkreise_simplify200.geojson")
STATES_GEO_PATH = Path("bundeslaender_simplify0.geojson")

# Reference year for the age calculation (age = AGE_REF_YEAR - construction year)
AGE_REF_YEAR = 2024

# First construction year used for the year-range marks
BAUJAHR_MIN = 1900


# -----------------------------
# Helpers
# -----------------------------
def pick_first(cols, candidates):
    """Return the first candidate that occurs in ``cols``, or ``None`` if none does."""
    present = [name for name in candidates if name in cols]
    return present[0] if present else None


def norm_name(s: str) -> str:
    """Normalize a county name into a key for name-based matching.

    The value is converted to ``str``, composed to NFC, trimmed and
    lowercased; German umlauts and "ß" are transliterated (ae/oe/ue/ss),
    any remaining non-ASCII characters are dropped, the administrative
    words "landkreis", "kreisfreie stadt", "kreis", "stadt", "region" and
    "lkr"/"lkr." are removed, and whitespace is collapsed.

    Returns:
        The normalized name (may be empty).
    """
    # Compose first so that decomposed umlauts ("u" + U+0308) are caught by
    # the transliteration below instead of being reduced to the bare vowel.
    s = unicodedata.normalize("NFC", str(s)).strip().lower()
    s = s.replace("ä", "ae").replace("ö", "oe").replace("ü", "ue").replace("ß", "ss")
    s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
    # The optional dot of "lkr." has to come after the word boundary: there
    # is no boundary between "." and a following space, so a pattern with
    # the dot in front of the boundary would leave a stray "." behind.
    s = re.sub(r"\b(?:landkreis|kreisfreie stadt|kreis|stadt|region)\b|\blkr\b\.?", "", s)
    s = re.sub(r"\s+", " ", s)
    return s.strip()


def to_float_series(s: pd.Series) -> pd.Series:
    """Convert a series to numbers, accepting a decimal comma.

    Values are cast to text, "," is replaced by ".", and anything that
    still cannot be parsed becomes NaN.
    """
    as_text = s.astype(str)
    with_decimal_point = as_text.str.replace(",", ".", regex=False)
    return pd.to_numeric(with_decimal_point, errors="coerce")


def safe_minmax(series: pd.Series):
    """Return ``(min, max)`` of the numeric values in ``series``, or ``None``.

    Non-numeric entries are ignored. ``None`` is returned when no numeric
    value is left. A range whose bounds coincide is widened symmetrically
    by a tiny amount so that it never has zero width.
    """
    numeric = pd.to_numeric(series, errors="coerce").dropna()
    if len(numeric) == 0:
        return None

    lower = float(numeric.min())
    upper = float(numeric.max())
    if lower == upper:
        # Pad relative to the value, or absolutely when the value is zero
        pad = 1e-6 if lower == 0 else abs(lower) * 1e-6
        return lower - pad, upper + pad
    return lower, upper


def safe_minmax_with_fallback(series: pd.Series, fallback=(1.0, 4.0)):
    """Return ``safe_minmax(series)``, or ``fallback`` when that yields ``None``."""
    bounds = safe_minmax(series)
    if bounds is None:
        return fallback
    return bounds


def make_points_from_lonlat(dfin, lon_col="x2", lat_col="y2"):
    """Create a point GeoDataFrame (EPSG:4326) from longitude/latitude columns.

    Rows where either coordinate cannot be parsed are dropped. When no row
    has usable coordinates, an empty GeoDataFrame with the columns of
    ``dfin`` is returned.
    """
    lon_vals = to_float_series(dfin[lon_col])
    lat_vals = to_float_series(dfin[lat_col])
    usable = lon_vals.notna() & lat_vals.notna()

    if usable.any():
        return gpd.GeoDataFrame(
            dfin.loc[usable].copy(),
            geometry=gpd.points_from_xy(lon_vals.loc[usable], lat_vals.loc[usable]),
            crs=4326,
        )

    # Nothing usable: empty frame that still carries the input columns
    return gpd.GeoDataFrame(dfin.iloc[0:0].copy(), geometry=[], crs=4326)


def make_points_from_utm(dfin, x_col="X", y_col="Y", epsg=25832):
    """Create a point GeoDataFrame from projected x/y columns, reprojected to EPSG:4326.

    Rows where either coordinate cannot be parsed are dropped. The points
    are created in the CRS given by ``epsg`` and then transformed. When no
    row has usable coordinates, an empty EPSG:4326 GeoDataFrame with the
    columns of ``dfin`` is returned.
    """
    x_vals = to_float_series(dfin[x_col])
    y_vals = to_float_series(dfin[y_col])
    usable = x_vals.notna() & y_vals.notna()

    if usable.any():
        points = list(map(Point, zip(x_vals.loc[usable], y_vals.loc[usable])))
        projected = gpd.GeoDataFrame(dfin.loc[usable].copy(), geometry=points, crs=epsg)
        return projected.to_crs(4326)

    # Nothing usable: empty frame that still carries the input columns
    return gpd.GeoDataFrame(dfin.iloc[0:0].copy(), geometry=[], crs=4326)


def filter_points(pts: gpd.GeoDataFrame, materials, year_range, roads) -> gpd.GeoDataFrame:
    """Return the rows of ``pts`` that satisfy all active filters (AND logic).

    An empty/None ``materials`` or ``roads`` selection means "no
    restriction". The year filter is applied only when ``year_range`` is a
    non-empty list or tuple of exactly two values; rows without a
    construction year are then excluded.
    """
    keep = pd.Series(True, index=pts.index)

    if materials:
        keep = keep & pts["Baustoffklasse"].isin(materials)

    has_year_range = (
        bool(year_range)
        and isinstance(year_range, (list, tuple))
        and len(year_range) == 2
    )
    if has_year_range:
        first_year, last_year = year_range
        built = pts["Baujahr"]
        keep = keep & (built.notna() & (built >= first_year) & (built <= last_year))

    if roads:
        keep = keep & pts["StraßenartLabel"].isin(roads)

    return pts.loc[keep]


# -----------------------------
# Load geometries
# -----------------------------
gdf_kreise = gpd.read_file(KREISE_GEO_PATH)
# Unify the CRS (WGS84); data without CRS information is labelled as EPSG:4326
if gdf_kreise.crs is None:
    gdf_kreise = gdf_kreise.set_crs(4326)
else:
    gdf_kreise = gdf_kreise.to_crs(4326)

# Detect the county name / key columns from a list of known candidates
colsK = gdf_kreise.columns
KREIS_NAME_COL = pick_first(colsK, ["GEN", "NAME_3", "NAME", "GEN_NAME", "KREIS"])
KREIS_AGS_COL = pick_first(colsK, ["AGS", "AGS_0", "RS", "RS_0", "ID_3"])
if KREIS_NAME_COL is None:
    raise ValueError(f"Could not detect county name column. Columns: {list(colsK)}")

gdf_states = gpd.read_file(STATES_GEO_PATH)
if gdf_states.crs is None:
    gdf_states = gdf_states.set_crs(4326)
else:
    gdf_states = gdf_states.to_crs(4326)

colsS = gdf_states.columns
STATE_NAME_COL = pick_first(colsS, ["GEN", "NAME", "STATE_NAME"])
STATE_ID_COL = pick_first(colsS, ["RS", "RS_0", "AGS", "AGS_0", "ID"])
if STATE_ID_COL is None:
    # No key column: fall back to the state name as identifier
    STATE_ID_COL = STATE_NAME_COL
if STATE_ID_COL is None:
    # Fail with a clear message instead of an opaque KeyError on a None column
    raise ValueError(f"Could not detect state id/name column. Columns: {list(colsS)}")
gdf_states["state_id"] = gdf_states[STATE_ID_COL].astype(str)

# GeoJSON: county features are identified by the key column if present,
# otherwise by the county name
gplot_json = gdf_kreise[[KREIS_NAME_COL, "geometry"]].copy()
if KREIS_AGS_COL is not None:
    gplot_json["id"] = gdf_kreise[KREIS_AGS_COL].astype(str)
else:
    gplot_json["id"] = gdf_kreise[KREIS_NAME_COL].astype(str)
geojson_kreise = json.loads(gplot_json[["id", "geometry"]].to_json())

states_json = gdf_states[["state_id", "geometry"]].copy()
geojson_states = json.loads(states_json.to_json())

# Outline of Germany: all states dissolved into a single geometry
germany_outline = gdf_states.dissolve().reset_index(drop=True)
germany_outline["id"] = "germany_outline"
outline_json = json.loads(germany_outline[["id", "geometry"]].to_json())


# -----------------------------
# Load bridges
# -----------------------------
df = pd.read_csv(CSV_PATH, sep=";", decimal=",", dtype=str, keep_default_na=False)

# Columns used later; missing ones are created as empty strings
needed = [
    "Zustandsnote", "Substanzkennzahl",
    "x2", "y2", "X", "Y",
    "Kreis", "Baujahr Überbau", "Traglastindex",
    "Zugeordneter Sachverhalt vereinfacht", "Baustoffklasse", "Bauwerksname"
]
for column in (name for name in needed if name not in df.columns):
    df[column] = ""

# Numeric columns (decimal comma accepted, unparsable values -> NaN)
for column in ("Zustandsnote", "Substanzkennzahl", "Baujahr Überbau"):
    df[column] = to_float_series(df[column])

# Prefer WGS84 (x2/y2); use UTM32 only if no WGS84 point exists
gdf_pts = make_points_from_lonlat(df, "x2", "y2")
used_coords = "WGS84 (x2/y2)"
if gdf_pts.empty:
    gdf_pts = make_points_from_utm(df, "X", "Y", epsg=25832)
    used_coords = "UTM32 (EPSG:25832)"

gdf_pts["lon"] = gdf_pts.geometry.x
gdf_pts["lat"] = gdf_pts.geometry.y
gdf_pts["Baujahr"] = pd.to_numeric(gdf_pts["Baujahr Überbau"], errors="coerce")

# Spatial join of points to counties (done once; skipped without points)
spatial_join_ok = not gdf_pts.empty
joined = None
if spatial_join_ok:
    if KREIS_AGS_COL is None:
        right_cols = [KREIS_NAME_COL, "geometry"]
    else:
        right_cols = [KREIS_NAME_COL, KREIS_AGS_COL, "geometry"]
    joined = gpd.sjoin(gdf_pts, gdf_kreise[right_cols], how="left", predicate="within")
# County name mapping fallback (name-based): normalized name -> county name
kreise_names = gdf_kreise[[KREIS_NAME_COL]].copy()
kreise_names["__kreis_norm__"] = kreise_names[KREIS_NAME_COL].map(norm_name)
name_map = dict(zip(kreise_names["__kreis_norm__"], gdf_kreise[KREIS_NAME_COL]))

base = joined if spatial_join_ok else df.copy()
base["__br_kreis_norm__"] = base["Kreis"].map(norm_name)
mapped_names = base["__br_kreis_norm__"].map(name_map)
if spatial_join_ok and KREIS_NAME_COL in base.columns:
    # Keep spatially matched names, fill the gaps from the name mapping
    base[KREIS_NAME_COL] = base[KREIS_NAME_COL].fillna(mapped_names)
else:
    # Without a spatial join only the mapped names are meaningful; a CSV
    # column that happens to share the name must not shadow them.
    base[KREIS_NAME_COL] = mapped_names

# Precompute county aggregation
roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
base["Traglastindex_norm"] = base["Traglastindex"].astype(str).str.upper().str.strip().map(roman_map)
base["Alter_Überbau"] = AGE_REF_YEAR - base["Baujahr Überbau"]

# NOTE(review): grouping is by name, so counties that share a name are
# pooled and receive identical statistics – confirm whether that is acceptable.
agg_stats = (
    base.dropna(subset=[KREIS_NAME_COL])
    .groupby(KREIS_NAME_COL, as_index=False)
    .agg(
        n_bridges=("Bauwerksname", "size"),
        mean_zust=("Zustandsnote", "mean"),
        mean_trag=("Traglastindex_norm", "mean"),
        mean_age=("Alter_Überbau", "mean"),
        mean_substanz=("Substanzkennzahl", "mean"),
        median_zust=("Zustandsnote", "median"),
    )
    .rename(columns={KREIS_NAME_COL: "NAME"})
)

# Attach the statistics to the geometries BEFORE the key column is added to
# them. If both sides of the merge carried the key column, pandas would
# rename it to "<key>_x"/"<key>_y" and the id below could no longer be taken
# from it, so the ids would not match the GeoJSON feature ids.
gplot = gdf_kreise.merge(agg_stats, left_on=KREIS_NAME_COL, right_on="NAME", how="left")
if KREIS_AGS_COL is not None:
    gplot["id"] = gplot[KREIS_AGS_COL].astype(str)
else:
    gplot["id"] = gplot[KREIS_NAME_COL].astype(str)

# Aggregation table, with the county key appended when one exists
if KREIS_AGS_COL is not None:
    kreise_keys = gdf_kreise[[KREIS_NAME_COL, KREIS_AGS_COL]].drop_duplicates()
    agg = agg_stats.merge(
        kreise_keys.rename(columns={KREIS_NAME_COL: "NAME"}),
        on="NAME",
        how="left",
    )
else:
    agg = agg_stats

# Road type labels (codes without an entry are shown unchanged)
mapping_strasse = {"B": "Bundesstraße", "O": "Bundesstraße", "A": "Autobahn"}
raw_road = gdf_pts["Zugeordneter Sachverhalt vereinfacht"].astype(str)
gdf_pts["StraßenartLabel"] = raw_road.map(lambda code: mapping_strasse.get(code, code))

# Hover text (precomputed once), one "<br>"-separated line per attribute
zust = pd.to_numeric(gdf_pts["Zustandsnote"], errors="coerce")
subs = pd.to_numeric(gdf_pts["Substanzkennzahl"], errors="coerce")
hover_lines = [
    "<b>" + gdf_pts["Bauwerksname"].astype(str) + "</b>",
    "Zustandsnote: " + zust.round(2).astype(str),
    "Substanzkennzahl: " + subs.round(2).astype(str),
    "Baustoffklasse: " + gdf_pts["Baustoffklasse"].astype(str),
    "Baujahr Überbau: " + gdf_pts["Baujahr Überbau"].astype(str),
    "Straßenart: " + gdf_pts["StraßenartLabel"].astype(str),
]
hover_text = hover_lines[0]
for hover_line in hover_lines[1:]:
    hover_text = hover_text + "<br>" + hover_line
gdf_pts["hover"] = hover_text

# UI options
layer_options = [
    {"label": label, "value": value}
    for label, value in (
        ("Outline", "outline"),
        ("States", "states"),
        ("Choropleth (counties)", "choropleth"),
        ("County borders", "kreise"),
    )
]

# Material / road type choices: distinct values, blanks skipped
baustoff_vals = sorted(
    value for value in gdf_pts["Baustoffklasse"].dropna().unique() if str(value).strip() != ""
)
baustoff_options = [{"label": value, "value": value} for value in baustoff_vals]

road_vals = sorted(
    value for value in gdf_pts["StraßenartLabel"].dropna().unique() if str(value).strip() != ""
)
road_options = [{"label": value, "value": value} for value in road_vals]

# Upper end of the year range: latest construction year in the data (or the
# reference year), rounded up to the next 25-year step counted from BAUJAHR_MIN
if gdf_pts["Baujahr"].notna().any():
    BAUJAHR_MAX_DATA = int(gdf_pts["Baujahr"].max())
else:
    BAUJAHR_MAX_DATA = AGE_REF_YEAR
rest = (BAUJAHR_MAX_DATA - BAUJAHR_MIN) % 25
if rest != 0:
    BAUJAHR_MAX = BAUJAHR_MAX_DATA + (25 - rest)
else:
    BAUJAHR_MAX = BAUJAHR_MAX_DATA
baujahr_marks = {}
for year in range(BAUJAHR_MIN, BAUJAHR_MAX + 1, 25):
    baujahr_marks[year] = str(year)

choropleth_mode_options = [
    {"label": label, "value": value}
    for label, value in (
        ("Condition (mean)", "zustand"),
        ("Load index (mean)", "trag"),
        ("Age (mean)", "alter"),
        ("Substance (mean)", "substanz"),
    )
]


# -----------------------------
# Figure builder
# -----------------------------
# Fully transparent fill used by the line-only (border) layers.
_TRANSPARENT_FILL = [[0, "rgba(0,0,0,0)"], [1, "rgba(0,0,0,0)"]]

# Choropleth mode -> (county column in gplot, display label, number format).
_CHOROPLETH_METRICS = {
    "zustand": ("mean_zust", "Mean condition", ".2f"),
    "trag": ("mean_trag", "Mean load index", ".2f"),
    "alter": ("mean_age", "Mean age (years)", ".1f"),
    "substanz": ("mean_substanz", "Mean substance", ".2f"),
}


def _points_metric(pts, choropleth_mode):
    """Return the per-bridge values of the metric selected for the choropleth."""
    if choropleth_mode == "trag":
        return pts["Traglastindex"].astype(str).str.upper().str.strip().map(roman_map)
    if choropleth_mode == "alter":
        return AGE_REF_YEAR - pd.to_numeric(pts["Baujahr Überbau"], errors="coerce")
    column = "Substanzkennzahl" if choropleth_mode == "substanz" else "Zustandsnote"
    return pd.to_numeric(pts[column], errors="coerce")


def _border_trace(geojson, locations, featureidkey, name, line_color, line_width):
    """Build a choropleth trace that draws only the outlines of its locations."""
    return go.Choropleth(
        geojson=geojson,
        locations=locations,
        z=[0] * len(locations),
        featureidkey=featureidkey,
        name=name,
        showlegend=True,
        showscale=False,
        hoverinfo="skip",
        # The fill is hidden through the transparent colorscale only.
        # marker.opacity must stay at its default: it applies to the whole
        # location shape (fill and outline), so an opacity of 0 would hide
        # the very border lines this layer exists to draw.
        colorscale=_TRANSPARENT_FILL,
        marker_line_color=line_color,
        marker_line_width=line_width,
    )


def _colorbar_slot(stacked, upper):
    """Return colorbar placement settings; empty unless two bars must share the side."""
    if not stacked:
        return {}
    if upper:
        return dict(len=0.45, y=1.0, yanchor="top")
    return dict(len=0.45, y=0.0, yanchor="bottom")


def build_figure(
    selected_layers,
    choropleth_mode,
    materials,
    year_range,
    roads,
    show_points,
    color_points,
):
    """Build the map figure for the current sidebar selection.

    Traces are added back to front: county fill, then border layers, then
    bridge points.

    Args:
        selected_layers: Collection of layer keys; recognised keys are
            "choropleth", "states", "kreise" and "outline".
        choropleth_mode: "zustand", "trag", "alter" or "substanz"; any other
            value is treated like "zustand".
        materials: Material filter, forwarded to ``filter_points``.
        year_range: Year filter, forwarded to ``filter_points``.
        roads: Road-type filter, forwarded to ``filter_points``.
        show_points: Whether to draw the filtered bridge points.
        color_points: Whether to color the points by their condition grade
            instead of a uniform dark color.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    want_choropleth = "choropleth" in selected_layers

    # The filtered points feed both the choropleth color range and the point
    # layer, so they are computed once and only when one of the two needs them.
    pts = None
    if want_choropleth or show_points:
        pts = filter_points(gdf_pts, materials, year_range, roads)

    # Two colorbars would otherwise be drawn at the same default position.
    points_bar = bool(show_points and color_points and pts is not None and not pts.empty)
    choropleth_bar = False

    # --- 1) Choropleth FIRST (background fill) ---
    if want_choropleth:
        metric_col, metric_label, fmt = _CHOROPLETH_METRICS.get(
            choropleth_mode, _CHOROPLETH_METRICS["zustand"]
        )

        # Color range: min/max of the filtered points, falling back to the
        # county means when the points yield no usable range.
        rng = safe_minmax(_points_metric(pts, choropleth_mode))
        if rng is None:
            rng = safe_minmax(gplot[metric_col])
        vmin, vmax = (None, None) if rng is None else rng

        has_value = gplot[metric_col].notna()
        gplot_valid = gplot[has_value]
        gplot_missing = gplot[~has_value]

        # Missing data (light grey fill)
        if not gplot_missing.empty:
            fig.add_trace(
                go.Choropleth(
                    geojson=geojson_kreise,
                    locations=gplot_missing["id"],
                    z=[0] * len(gplot_missing),
                    featureidkey="properties.id",
                    name="No data",
                    showlegend=True,
                    showscale=False,
                    colorscale=[[0, "lightgrey"], [1, "lightgrey"]],
                    hoverinfo="skip",
                    marker_line_color="rgba(0,0,0,0)",
                    marker_line_width=0.0,
                    marker_opacity=1.0,
                )
            )

        # Valid data (colored fill)
        if not gplot_valid.empty:
            cb_title = metric_label
            if vmin is not None and vmax is not None:
                cb_title = f"{metric_label}<br>min={vmin:{fmt}} / max={vmax:{fmt}}"

            fig.add_trace(
                go.Choropleth(
                    geojson=geojson_kreise,
                    locations=gplot_valid["id"],
                    z=gplot_valid[metric_col],
                    featureidkey="properties.id",
                    name=metric_label + " (counties)",
                    showlegend=True,
                    showscale=True,
                    colorbar=dict(
                        title=cb_title,
                        **_colorbar_slot(points_bar, upper=True),
                    ),
                    colorscale="RdYlGn_r",
                    zmin=vmin,
                    zmax=vmax,
                    marker_opacity=1.0,
                    hovertemplate=f"<b>%{{text}}</b><br>{metric_label}: %{{z:{fmt}}}<extra></extra>",
                    text=gplot_valid["NAME"],
                    marker_line_color="rgba(0,0,0,0)",
                    marker_line_width=0.0,
                )
            )
            choropleth_bar = True

    # --- 2) Optional borders AFTER choropleth (no fill, lines only) ---
    if "states" in selected_layers:
        fig.add_trace(
            _border_trace(
                geojson_states,
                gdf_states["state_id"],
                "properties.state_id",
                "States",
                "rgba(190,190,190,0.95)",
                1.7,
            )
        )

    if "kreise" in selected_layers:
        fig.add_trace(
            _border_trace(
                geojson_kreise,
                gplot["id"],
                "properties.id",
                "County borders",
                "rgba(170,170,170,0.9)",
                0.4,
            )
        )

    if "outline" in selected_layers:
        fig.add_trace(
            _border_trace(
                outline_json,
                ["germany_outline"],
                "properties.id",
                "Outline",
                "rgba(120,120,120,0.9)",
                1.2,
            )
        )

    # --- 3) Bridge points LAST (on top) ---
    if show_points and not pts.empty:
        marker = dict(size=4.0, opacity=0.85)

        if color_points:
            pts_zust = pd.to_numeric(pts["Zustandsnote"], errors="coerce")
            zmin_pts, zmax_pts = safe_minmax_with_fallback(pts_zust, fallback=(1.0, 4.0))
            marker.update(
                dict(
                    color=pts_zust,
                    cmin=zmin_pts,
                    cmax=zmax_pts,
                    colorscale="RdYlGn_r",
                    showscale=True,
                    colorbar=dict(
                        title="Condition<br>(green good / red bad)",
                        **_colorbar_slot(choropleth_bar, upper=False),
                    ),
                    line=dict(color="rgba(0,0,0,0.8)", width=0.05),
                )
            )
        else:
            marker.update(dict(color="rgba(0,0,0,0.65)"))

        fig.add_trace(
            go.Scattergeo(
                lon=pts["lon"],
                lat=pts["lat"],
                mode="markers",
                name="Bridges (filtered)",
                text=pts["hover"],
                hovertemplate="%{text}<extra></extra>",
                marker=marker,
                showlegend=True,
            )
        )

    fig.update_geos(
        fitbounds="locations",
        visible=False,
        projection_type="mercator",
        center=dict(lat=51, lon=10),
        projection_scale=7,
    )

    fig.update_layout(
        margin={"r": 20, "t": 60, "l": 20, "b": 20},
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)"),
        title=(
            "County metrics (condition / load / age / substance)<br>"
            f"<sup>Mapping: {used_coords if spatial_join_ok else 'name-based'}; "
            "scale: min/max from filtered bridge points</sup>"
        ),
    )

    return fig



# -----------------------------
# Dash app
# -----------------------------
app = dash.Dash(__name__)

# Every checklist / radio entry is rendered on its own line.
_one_per_line = {"display": "block"}

# Left column: all filter and layer controls, separated by horizontal rules.
_sidebar = html.Div(
    style={
        "width": "22%",
        "padding": "10px",
        "borderRight": "1px solid #ccc",
        "overflowY": "auto",
    },
    children=[
        html.H3("Layers"),
        dcc.Checklist(
            id="layer-checklist",
            options=layer_options,
            # All layers are switched on initially.
            value=[option["value"] for option in layer_options],
            labelStyle=_one_per_line,
        ),
        html.Hr(),
        html.H3("Choropleth metric"),
        dcc.RadioItems(
            id="choropleth-mode",
            options=choropleth_mode_options,
            value="zustand",
            labelStyle=_one_per_line,
        ),
        html.Hr(),
        html.H3("Bridge points"),
        dcc.Checklist(
            id="points-toggle",
            options=[
                {"label": "Show bridge points", "value": "show"},
                {"label": "Color points by condition", "value": "color"},
            ],
            value=["show", "color"],
            labelStyle=_one_per_line,
        ),
        html.Hr(),
        html.H3("Material class"),
        dcc.Checklist(
            id="baustoff-checklist",
            options=baustoff_options,
            value=[],
            labelStyle=_one_per_line,
        ),
        html.Hr(),
        html.H3("Year (superstructure)"),
        dcc.RangeSlider(
            id="year-slider",
            min=BAUJAHR_MIN,
            max=BAUJAHR_MAX,
            step=1,
            # The full range is selected initially.
            value=[BAUJAHR_MIN, BAUJAHR_MAX],
            marks=baujahr_marks,
            allowCross=False,
            tooltip={"always_visible": False, "placement": "bottom"},
        ),
        html.Hr(),
        html.H3("Road type"),
        dcc.Checklist(
            id="road-checklist",
            options=road_options,
            value=[],
            labelStyle=_one_per_line,
        ),
    ],
)

# Right column: the map graph, whose figure is supplied by the callback.
_map_panel = html.Div(
    style={"width": "78%", "padding": "10px"},
    children=[dcc.Graph(id="map-figure", style={"height": "100%"})],
)

# Page: sidebar and map side by side, filling the viewport height.
app.layout = html.Div(
    style={"display": "flex", "height": "100vh", "fontFamily": "sans-serif"},
    children=[_sidebar, _map_panel],
)


@app.callback(
    Output("map-figure", "figure"),
    [
        Input("layer-checklist", "value"),
        Input("choropleth-mode", "value"),
        Input("points-toggle", "value"),
        Input("baustoff-checklist", "value"),
        Input("year-slider", "value"),
        Input("road-checklist", "value"),
    ],
)
def update_map(selected_layers, choropleth_mode, points_toggle, materials, year_range, roads):
    """Rebuild the map figure whenever one of the sidebar controls changes.

    Empty or missing control values are replaced by neutral defaults (empty
    selections, the "zustand" metric, the full year range) before the figure
    is built.
    """
    toggles = points_toggle or []

    return build_figure(
        selected_layers=selected_layers or [],
        choropleth_mode=choropleth_mode or "zustand",
        materials=materials or [],
        year_range=year_range or [BAUJAHR_MIN, BAUJAHR_MAX],
        roads=roads or [],
        show_points="show" in toggles,
        color_points="color" in toggles,
    )


# Start the Dash server (debug mode) only when the file is run as a script.
# NOTE(review): the module docstring advertises http://127.0.0.1:8051, but the
# server is started on port 8063 — confirm which port is intended.
if __name__ == "__main__":
    app.run(debug=True, port=8063)
