In [None]:
from IPython.display import display, HTML
# inject a CSS rule so that elements with class "container" span the full page width
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
from glob import glob
import os
from typing import Callable

import geopandas as gpd
import pandas as pd
from pyproj import Geod
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import rioxarray
import rasterio
from shapely.geometry import LineString
from snail.core.intersections import split_linestring
from snail.core.intersections import get_cell_indices as get_cell_indicies_of_midpoint
from scipy.integrate import simpson

from utils import aqueduct_rp

In [None]:
# top-level folder holding the input data
data_dir = "data"
# country code; appears as a folder name and as the network file stem in the input paths
country_iso = "bgd"

In [None]:
# flood hazard data to use, pulled from the autopkg API
epoch = 2050
scenario = "rcp4p5"

# build every input path from the configured data_dir / country_iso pair, so
# that changing the configuration cell actually changes what is loaded
country_dir = os.path.join(data_dir, country_iso)
raster_pattern = os.path.join(country_dir, "wri_aqueduct", f"*{scenario}*{epoch}*.tif")
raster_paths = glob(raster_pattern)

# fail here with a clear message, rather than later when the (empty) list of
# rasters is unpacked
if not raster_paths:
    raise FileNotFoundError(f"No hazard rasters found matching {raster_pattern}")

# order by the aqueduct_rp key, largest first
raster_paths = sorted(raster_paths, key=aqueduct_rp, reverse=True)

network = gpd.read_file(os.path.join(country_dir, "gri_osm", f"{country_iso}.gpkg"))

In [None]:
def check_raster_grid_consistent(raster_paths: list[str]) -> None:
    """
    Verify that every raster in a collection shares one grid definition.

    The first path acts as the reference: each remaining raster is opened and
    its width, height and transform are compared against it. With zero or one
    path there is nothing to compare and no file is opened.

    Args:
        raster_paths: Paths of the raster files to compare.

    Raises:
        AttributeError: If the width, height or transform of any raster differs
            from that of the first raster.
    """
    # nothing to compare against
    if len(raster_paths) <= 1:
        return

    first_path, *remaining_paths = raster_paths

    # grid definition of the reference raster
    with rasterio.open(first_path) as dataset:
        raster_width = dataset.width
        raster_height = dataset.height
        raster_transform = list(dataset.transform)

    for raster_path in remaining_paths:
        with rasterio.open(raster_path) as raster:
            grid_matches = (
                raster_width == raster.width
                and raster_height == raster.height
                and raster_transform == list(raster.transform)
            )
            if grid_matches:
                continue

            message = (
                f"Raster attribute mismatch in file {raster_path}:\n"
                f"Height: expected={raster_height}; actual={raster.height}\n"
                f"Width: expected={raster_width}; actual={raster.width}\n"
                f"Transform equal? {'True' if list(raster.transform) == raster_transform else 'False'}"
            )
            raise AttributeError(message)

In [None]:
def split_linestrings(features: gpd.GeoDataFrame, raster: rasterio.io.DatasetReader) -> gpd.GeoDataFrame:
    """
    Split feature linestrings on a raster grid.

    Every geometry in `features` is handed to `split_linestring` together with
    the raster's width, height and transform; the resulting pieces are
    gathered into a single table.

    Args:
        features: Table of features. Every geometry must be a LineString.
        raster: Open raster whose width, height and transform describe the grid.

    Returns:
        GeoDataFrame with one row per piece. "original_index" holds the index
        label of the row in `features` the piece came from, "geometry" the
        piece itself. The CRS of `features` is carried over to the result.

    Raises:
        ValueError: If the set of geometry types present in `features` is
            anything other than exactly {"LineString"}. An empty table has no
            geometry types at all and is therefore rejected too.
    """

    if set(features.geometry.geom_type) != {"LineString"}:
        raise ValueError("Can only split LineString geometries")

    # the grid definition is identical for every feature, so read it once
    width = raster.width
    height = raster.height
    transform = list(raster.transform)

    all_splits = []
    all_indicies = []
    for edge in features.itertuples():
        split_geoms = split_linestring(edge.geometry, width, height, transform)
        all_splits.extend(split_geoms)
        # one copy of the source row label per piece, keeping both lists aligned
        all_indicies.extend([edge.Index] * len(split_geoms))

    # pass the CRS along explicitly; a frame built from bare geometries would
    # otherwise have no CRS information
    return gpd.GeoDataFrame(
        {"original_index": all_indicies, "geometry": all_splits},
        crs=features.crs,
    )

In [None]:
def cell_indicies_assigner(raster: rasterio.io.DatasetReader) -> Callable:
    """
    Build a lookup function bound to the grid of an open raster.

    Args:
        raster: Open raster; its height, width and transform are read each
            time the returned function is called.

    Returns:
        Function mapping a geometry to a pd.Series holding its grid cell
        indicies, labelled "raster_i" and "raster_j".
    """

    def cell_indicies_of_split_geometry(geometry, *args, **kwargs) -> pd.Series:
        """
        Look up the cell index (i, j) of the midpoint of `geometry` on the
        grid of the enclosing raster.

        Extra positional and keyword arguments are accepted and ignored.

        N.B. Nothing here checks whether the geometry spans more than one cell.
        """

        # integer indicies of the cell containing the midpoint
        cell_i, cell_j = get_cell_indicies_of_midpoint(
            geometry,
            raster.height,
            raster.width,
            raster.transform,
        )

        # sanity check: the cell must lie inside the raster
        assert 0 <= cell_i < raster.width
        assert 0 <= cell_j < raster.height

        # labelled values, so that applying this over a column of geometries
        # unpacks into two dataframe columns
        return pd.Series(data=[cell_i, cell_j], index=("raster_i", "raster_j"))

    return cell_indicies_of_split_geometry

In [None]:
# filter to linestrings (edges)
lines = network[network.geometry.type == "LineString"]

# error if grids not consistent
check_raster_grid_consistent(raster_paths)

# the grids were just checked to be identical, so the first raster can stand
# in for all of them
raster_path, *other_raster_paths = raster_paths

# keep the raster open only while its grid definition is needed, so the file
# handle is released again instead of lingering for the life of the kernel
with rasterio.open(raster_path) as raster:
    # split edges on raster grid
    splits = split_linestrings(lines, raster)

    # which cell is each split edge in?
    assigner = cell_indicies_assigner(raster)
    raster_indicies = splits.geometry.apply(assigner)

# calculate split edge lengths
geod = Geod(ellps="WGS84")
meters_per_km = 1_000
splits["length_km"] = splits.geometry.apply(geod.geometry_length) / meters_per_km

# join raster indicies to geometries with shared index
splits_with_raster_indicies = splits.join(raster_indicies)

In [None]:
# map raster indicies as visual check: one panel per index column
f, (ax_i, ax_j) = plt.subplots(ncols=2)
panels = (
    (ax_i, "raster_i", "viridis"),
    (ax_j, "raster_j", "cubehelix"),
)
for panel_ax, index_column, colour_map in panels:
    splits_with_raster_indicies.plot(ax=panel_ax, column=index_column, cmap=colour_map, legend=True)
    panel_ax.set_title(index_column)
f.tight_layout()

In [None]:
def raster_lookup(df: pd.DataFrame, fname: str, band_number: int=1) -> pd.Series:
    """
    For each split geometry, lookup the relevant raster value. Cell indicies
    must have been previously calculated and stored as "raster_i" and "raster_j".

    Values below 1E-6 (zero and negative values included) are returned as NaN.

    Args:
        df (pd.DataFrame): Table of features, each with the cell indicies of
            the relevant raster pixel stored in the columns "raster_i"
            (column index) and "raster_j" (row index).
        fname (str): Filename of raster file to read data from
        band_number (int): Which band of the raster file to read

    Returns:
        pd.Series: Series of raster values, with same row indexing as df.
    """

    # the band is read fully into memory, so the file can be closed straight away
    with rasterio.open(fname) as dataset:
        band_data: np.ndarray = dataset.read(band_number)

    # NaN can only be stored in floating point arrays; promote any other dtype
    # first so that the masking below cannot fail on an integer band
    if not np.issubdtype(band_data.dtype, np.floating):
        band_data = band_data.astype(float)

    # set values below the threshold to NaN
    band_data[band_data < 1E-6] = np.nan

    # 2D numpy indexing is j, i (i.e. row, column)
    rows = df["raster_j"].to_numpy()
    cols = df["raster_i"].to_numpy()
    return pd.Series(index=df.index, data=band_data[rows, cols])

In [None]:
# add one column of raster values per hazard raster, named "rp-<aqueduct_rp of the path>"
for path in raster_paths:
    intensity_column = f"rp-{aqueduct_rp(path)}"
    splits_with_raster_indicies[intensity_column] = raster_lookup(splits_with_raster_indicies, path)

# N.B. this is a second name for the same table, not a copy
hazard_intensities = splits_with_raster_indicies
hazard_intensities.describe()

In [None]:
def logistic_min(x: float | np.ndarray, L: float, m: float, k: float, x_0: float) -> float | np.ndarray:
    """
    Logistic function with a minimum value, m.

    Args:
        x: Input values
        L: Maximum output value
        m: Minimum output value
        k: Steepness parameter
        x_0: Location of sigmoid centre in x

    Returns:
        Output values
    """

    return m + (L - m) / (1 + np.exp(-k * (x - x_0)))

# define a damage function
def damage_curve(x):
    """Logistic curve between 0 and 1, with steepness 2 and centred on x = 2."""
    return logistic_min(x, 1, 0, 2, 2)

# have a look at it
x = np.linspace(0, 5, 20)
f, ax = plt.subplots()
ax.scatter(x, damage_curve(x))
ax.set(
    xlabel="Flood depth [meters]",
    ylabel="Damage fraction",
    title="Damage function",
)
ax.grid()

In [None]:
# calculate how badly each split edge is damaged by the flooding
hazard_cols = [col for col in hazard_intensities.columns if col.startswith("rp-")]
damage_fractions = hazard_intensities.copy()

# The damage curve is plain numpy arithmetic (it is already called on a whole
# array when plotted), so evaluate it on all hazard columns in one go instead
# of cell by cell through DataFrame.applymap, which pandas has deprecated.
# Casting to float first keeps the calculation in double precision.
damage_fractions[hazard_cols] = damage_curve(hazard_intensities[hazard_cols].astype(float))

In [None]:
# calculate the cost of damage
reconstruction_cost_currency_per_km = 1E4

# cost of each split edge = damage fraction x length x unit reconstruction cost
damage_cost = damage_fractions.copy()
for hazard_col in hazard_cols:
    fraction = damage_cost[hazard_col]
    damage_cost[hazard_col] = fraction * damage_cost.length_km * reconstruction_cost_currency_per_km

# total the split edges back up to the edges they were cut from
grouped_damage_cost = damage_cost.groupby("original_index")[hazard_cols].sum()

# the number in each "rp-<n>" column name, inverted to a probability per year
return_periods = np.array([int(col.replace("rp-", "")) for col in hazard_cols])
probability_per_year = 1 / return_periods

# same costs, with columns labelled by probability rather than by name
damage_probability_curve = grouped_damage_cost.set_axis(probability_per_year, axis=1)

In [None]:
# plot the damage-probability curve: cost summed over all edges, per probability
damage_probability_sum = damage_probability_curve.sum()

f, ax = plt.subplots()
damage_probability_sum.plot(ax=ax)
ax.set_title(f"Damage-probability curve\n{scenario.upper()} {epoch}")
ax.set_xlabel("Probability per given year")
ax.set_ylabel("Damage cost [currency]")
ax.grid()

In [None]:
# integrate damage cost over probability per year, one value per original edge
ead_by_edge = pd.Series(
    simpson(grouped_damage_cost.to_numpy(), x=probability_per_year, axis=1),
    index=grouped_damage_cost.index,
)

EAD = lines[["geometry"]].copy()
# Assign by index label rather than by position: the grouped costs are keyed
# by the index of the edge each piece came from, so every value lands on its
# own edge. An edge with no entry in the grouped costs is left as NaN instead
# of shifting the remaining values or raising a length mismatch.
EAD["ead"] = ead_by_edge

f, ax = plt.subplots(figsize=(10,10))
EAD.plot(
    ax=ax,
    column="ead",
    legend=True,
    cmap="RdPu",
    norm=matplotlib.colors.LogNorm(vmin=1E0, vmax=EAD.ead.max())
)
ax.grid()
ax.set_title(f"Expected Damages [currency per annum]\nTotal: {EAD.ead.sum():.2E}")