In [None]:
import sys

# Put the repository's ``src`` folder first on the import path so the
# ``coastmonitor`` import below can be resolved from there.
sys.path.insert(0, "../src")

from coastmonitor.io.drive_config import configure_instance

configure_instance(branch="dev")
import dask

# The key must be spelled "dataframe.query-planning"; the previous spelling
# ("datatframe") set an unrelated key, so the option was never applied.
# NOTE(review): dask reads this option when ``dask.dataframe`` is first
# imported, so it has to stay ahead of any import that pulls it in — confirm.
dask.config.set({"dataframe.query-planning": False})
import logging
import os
import pathlib

import duckdb
import geopandas as gpd
import hvplot.pandas
import pandas as pd
import pystac
from dotenv import load_dotenv

from coastmonitor.io.utils import read_items_extent
from coastmonitor.query_engine import HREFQueryEngine, STACQueryEngine

# Load variables from a local .env file; override=True lets values from the
# file replace variables that are already set in the environment.
load_dotenv(override=True)

# NOTE: access tokens to the data are available upon request.
# os.getenv returns None when a variable is not set.
sas_token = os.getenv("AZURE_STORAGE_SAS_TOKEN")
account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
storage_options = {"account_name": account_name, "credential": sas_token}

# URL of the STAC catalog that we use to efficiently index the data
COCLICO_STAC_URL = "https://coclico.blob.core.windows.net/stac/v1/catalog.json"

# Global Coastal Transect System (publicly available and in review)
GCTS_COLLECTION_NAME = "gcts"

# Global Coastal Transect Repository (unreleased; access keys provided upon request). This dataset consists
# of GCTS + several other characteristics, such as intersection distance to nearest coastline.
GCTR_COLLECTION_NAME = "gctr"

# ShorelineMonitor Raw Series (unreleased; access keys provided upon request). This dataset consists of
# ShorelineMonitor shorelines that are mapped onto the Global Coastal Transect System (Raw Series) and that
# have a wide range of additional statistics used to select the primary, high-quality observations.
SM_COLLECTION_NAME = "shorelinemonitor-raw-series"

## Read the STAC collections

In [None]:
# Open the root catalog, then look up the two child collections by id.
# NOTE(review): pystac's get_child returns None when no child has the given
# id — confirm both collections exist before using them below.
coclico_catalog = pystac.Catalog.from_file(COCLICO_STAC_URL)
sm_collection = coclico_catalog.get_child(SM_COLLECTION_NAME)
gcts_collection = coclico_catalog.get_child(GCTS_COLLECTION_NAME)

## Show the spatial extents of both collections

In [None]:
# Read the item extents of each collection with the project helper.
sm_extents = read_items_extent(sm_collection)
gcts_extents = read_items_extent(gcts_collection)
# Only the ShorelineMonitor extents are rendered: this is the cell's last
# expression, and gcts_extents is not displayed here.
sm_extents[["geometry"]].explore()

## Create an interactive map that we use to define our region of interest

In [None]:
from ipyleaflet import Map, basemaps

# Interactive map on an Esri imagery basemap; pan and zoom it to the region
# of interest before running the next cell.
m = Map(basemap=basemaps.Esri.WorldImagery, scroll_wheel_zoom=True)
m.center = 53.4, 5.4  # initial view centre — presumably (lat, lon); confirm
m.zoom = 11
m.layout.height = "800px"
# Last expression of the cell, so the widget is displayed.
m

In [None]:
# NOTE: these coordinates are extracted from the interactive map above.
# NOTE(review): they are read from the live widget `m`, so the bounding box
# depends on the map view at the moment this cell runs.
minx, miny, maxx, maxy = m.west, m.south, m.east, m.north

## Create a DuckDB query engine to retrieve data from cloud storage

In [None]:
# Query engine bound to the ShorelineMonitor raw-series collection, using the
# "azure" storage backend.
shoreline_engine = STACQueryEngine(
    stac_collection=sm_collection,
    storage_backend="azure",
)

In [None]:
# Query the ShorelineMonitor collection for the bounding box taken from the map.
shorelines = shoreline_engine.get_data_within_bbox(minx, miny, maxx, maxy)

In [None]:
# Same query as for the shorelines, but against the GCTS transect collection;
# `transects` later supplies the geometry for each `tr_name`.
transects_engine = STACQueryEngine(
    stac_collection=gcts_collection, storage_backend="azure"
)
transects = transects_engine.get_data_within_bbox(minx, miny, maxx, maxy)

In [None]:
# Work on a copy so the raw query result in `shorelines` stays untouched.
# (A previous assignment that subset the NaN-position rows was overwritten
# immediately and has been removed as dead code.)
s = shorelines.copy()

# Reference chainage per transect: order the rows chronologically, then take
# the last non-null `shoreline_chainage` of every `tr_name` group.
last_obs = (
    s.sort_values(by=["time"])
    .groupby("tr_name")["shoreline_chainage"]
    .last()
    .rename("last_obs")
)

# Broadcast the per-transect reference back onto every observation.
s = s.merge(last_obs, on="tr_name", how="left")

# Express each position relative to the transect's reference observation.
s["shoreline_position"] = s["shoreline_chainage"] - s["last_obs"]

In [None]:
# Count the observations whose obs_is_primary flag is False
# (assumes a boolean column — confirm).
(~shorelines.obs_is_primary).sum()

In [None]:
# Same count, restricted to the rows where shoreline_position is missing.
(~shorelines.loc[shorelines["shoreline_position"].isna()].obs_is_primary).sum()

In [None]:
# Largest per-transect reference chainage (the call was missing its closing
# parenthesis, which made the cell a SyntaxError).
s["last_obs"].max()

In [None]:
# Inspect the positions stored in `s` (chainage minus the transect's `last_obs`).
s["shoreline_position"]

In [None]:
# Inspect the per-transect reference chainage computed from the time-sorted rows.
last_obs

In [None]:
from typing import List, Tuple

import pandas as pd
from pandas import DataFrame


def flag_obs_ht_max_step_change(
    df: DataFrame,
    max_step_change: float = 50,
    max_n_step_changes: int = 4,
    max_year_interval: int = 10,
) -> DataFrame:
    """
    Flag observations that are isolated step changes and transects that have too many of them.

    Differences are taken between consecutive observations of the same transect, ordered by
    'time'. Three situations are flagged in 'obs_ht_max_step_change':

    * mid-series: the jump into the observation and the jump out of it both exceed
      ``max_step_change`` in magnitude, have opposite signs, and both span fewer than
      ``max_year_interval`` calendar years;
    * first observation: the jump to the second observation exceeds ``max_step_change``
      while the following jump (second to third) stays below it;
    * last observation: the jump from the second-to-last observation exceeds
      ``max_step_change`` while the preceding jump stays below it.

    Transects with fewer than three observations can therefore never be flagged.

    Args:
        df (DataFrame): Shoreline position data. Must include the columns 'time' (datetime),
                        'tr_name', 'shoreline_position', and 'geometry'.
        max_step_change (float): Magnitude above which a difference counts as a step. Default is 50.
        max_n_step_changes (int): A transect is flagged as 'unsteady' once its number of flagged
                                  observations is greater than or equal to this value. Default is 4.
        max_year_interval (int): Upper bound (exclusive) on the year gap on either side of a
                                 mid-series step. Default is 10.

    Returns:
        DataFrame: A copy of the four input columns, in the input row order and with a fresh
                   integer index, plus the boolean columns 'obs_ht_max_step_change' and
                   'tr_is_unsteady'.
    """
    # Copy relevant columns; a fresh positional index lets the flags computed on the
    # sorted view below align back onto the original row order without ambiguity.
    df = df[["time", "tr_name", "shoreline_position", "geometry"]].copy()
    df = df.reset_index(drop=True)

    # Differences only make sense between chronologically consecutive observations.
    ordered = df.sort_values(["tr_name", "time"])
    tr = ordered["tr_name"]
    by_transect = ordered.groupby("tr_name", sort=False)

    # Every shift is done per transect so values never leak across transect boundaries.
    backward_diff = by_transect["shoreline_position"].diff()
    forward_diff = backward_diff.groupby(tr, sort=False).shift(-1)
    year = ordered["time"].dt.year
    dt_backward = year.groupby(tr, sort=False).diff()
    dt_forward = dt_backward.groupby(tr, sort=False).shift(-1)

    # Step change in the middle of the series: a spike out and back again.
    mid_step_change = (
        (backward_diff.abs() > max_step_change)
        & (forward_diff.abs() > max_step_change)
        & (backward_diff * forward_diff < 0)
        & (dt_backward < max_year_interval)
        & (dt_forward < max_year_interval)
    )

    # Position of each row inside its transect, counted from the start and from the end.
    is_first = by_transect.cumcount() == 0
    is_last = by_transect.cumcount(ascending=False) == 0

    # Neighbouring differences used to confirm that the series is calm next to the jump.
    # Comparisons against NaN are False, so short transects are simply not flagged.
    next_forward_diff = forward_diff.groupby(tr, sort=False).shift(-1)
    prev_backward_diff = backward_diff.groupby(tr, sort=False).shift(1)

    first_step_change = (
        is_first
        & (forward_diff.abs() > max_step_change)
        & (next_forward_diff.abs() < max_step_change)
    )
    last_step_change = (
        is_last
        & (backward_diff.abs() > max_step_change)
        & (prev_backward_diff.abs() < max_step_change)
    )

    # Index alignment maps the flags from the sorted view back to the input row order.
    df["obs_ht_max_step_change"] = (
        mid_step_change | first_step_change | last_step_change
    )

    # Count the flagged observations per transect and broadcast the verdict to every row.
    n_step_changes = df.groupby("tr_name")["obs_ht_max_step_change"].transform("sum")
    df["tr_is_unsteady"] = n_step_changes >= max_n_step_changes

    return df


def clean_raw_shorelinemonitor_series(
    df: DataFrame,
    columns: List[str] = ["time", "tr_name", "shoreline_position", "geometry"],
    sinuosity_threshold: float = 10,
    mdn_offset_multiplier: float = 3,
    min_obs_count: int = 5,
    max_step_change: float = 50,
    max_n_step_changes: int = 4,
    max_year_interval: int = 10,
) -> DataFrame:
    """
    Cleans and filters shoreline position data based on specified criteria.

    The raw rows are first filtered on the quality columns, then step changes are flagged with
    ``flag_obs_ht_max_step_change``. Flagged observations and all observations of 'unsteady'
    transects are dropped, and 'obs_count' is recomputed from the rows that remain.

    Args:
        df (DataFrame): Raw shoreline position data with required columns:
                        'shoreline_sinuosity', 'is_shoal', 'obs_is_primary',
                        'tr_is_qa', 'mdn_offset', 'tr_stdev', 'obs_is_outlier',
                        'obs_count', 'time', 'tr_name', 'shoreline_position', 'geometry'.
        columns (List[str]): Columns to return. Available after cleaning: 'time', 'tr_name',
                             'shoreline_position', 'geometry', 'obs_ht_max_step_change',
                             'tr_is_unsteady', 'obs_count'.
                             Default is ["time", "tr_name", "shoreline_position", "geometry"].
        sinuosity_threshold (float): Rows must have a sinuosity below this value. Default is 10.
        mdn_offset_multiplier (float): Rows must have 'mdn_offset' below this multiple of 'tr_stdev'. Default is 3.
        min_obs_count (int): Minimum raw 'obs_count' for a row to be kept. Default is 5.
        max_step_change (float): Threshold for detecting significant step changes. Default is 50.
        max_n_step_changes (int): Number of step changes at which a transect is flagged as 'unsteady'. Default is 4.
        max_year_interval (int): Maximum year difference for considering a step change significant. Default is 10.

    Returns:
        DataFrame: Cleaned shoreline positions restricted to ``columns``, with a fresh integer
                   index and 'obs_count' recalculated per transect.
    """
    # Filtering criteria for clean shoreline positions
    is_clean = (
        (df["shoreline_sinuosity"] < sinuosity_threshold)
        & (~df["is_shoal"])
        & (df["obs_is_primary"])
        & (df["tr_is_qa"])
        & (df["mdn_offset"] < mdn_offset_multiplier * df["tr_stdev"])
        & (df["obs_count"] >= min_obs_count)
        & (df["obs_is_outlier"] != 1)
    )
    df = df[is_clean].copy()

    # Detect and flag step changes
    df = flag_obs_ht_max_step_change(
        df,
        max_step_change=max_step_change,
        max_n_step_changes=max_n_step_changes,
        max_year_interval=max_year_interval,
    )

    # Drop the flagged observations and every observation on an unsteady transect.
    df = df.loc[(~df["tr_is_unsteady"]) & (~df["obs_ht_max_step_change"])].copy()

    # Count the clean observations on each transect; the raw 'obs_count' no longer
    # applies after filtering.
    df["obs_count"] = df.groupby("tr_name")["shoreline_position"].transform("count")

    # `columns` is copied so the shared default list is never handed to pandas directly.
    return df[list(columns)].reset_index(drop=True)


# Run the cleaning pipeline on the raw query result. "obs_count" is requested
# as an extra output column; the remaining arguments repeat the function defaults.
df_clean = clean_raw_shorelinemonitor_series(
    shorelines,
    columns=["time", "tr_name", "shoreline_position", "geometry", "obs_count"],
    sinuosity_threshold=10,
    mdn_offset_multiplier=3,
    min_obs_count=5,
    max_step_change=50,
    max_n_step_changes=4,
    max_year_interval=10,
)

In [None]:
# Display the result of the cleaning step (last expression of the cell).
df_clean

In [None]:
import numpy as np
from scipy import linalg, signal
from tqdm import tqdm


def ols_AC(group):
    """
    Fit a straight line through the shoreline positions of one group by ordinary least squares.

    The regressor is the calendar year taken from the 'time' column; the design matrix holds
    a column of ones (intercept) followed by the year (slope).

    Parameters:
        group (DataFrame): A pandas DataFrame with the columns 'shoreline_position' and 'time',
                           where 'time' has a datetime dtype.

    Returns:
        tuple:
            - The least-squares coefficients as an array ``[intercept, slope]``,
            - The residual value reported by ``scipy.linalg.lstsq`` for the fit
              (sum of squared residuals; an empty array when SciPy does not compute it).
    """
    positions = group["shoreline_position"].values
    years = group["time"].dt.year.values

    # Design matrix: constant term in the first column, year in the second
    design = np.column_stack((np.ones_like(years), years))

    # Only the first two entries of the lstsq result are needed (coefficients, residues)
    solution = linalg.lstsq(design, positions)
    coefficients = solution[0]
    residual_ss = solution[1]

    return coefficients, residual_ss


def all_AC(shoreline_positions, obs_count_threshold=5):
    """
    Applies OLS regression across groups of shoreline position data, each group identified by 'tr_name',
    and aggregates the results into a DataFrame.

    A transect is only fitted when the 'obs_count' value of its first row is strictly greater
    than ``obs_count_threshold``; all other transects are left out of the result.

    Parameters:
        shoreline_positions (DataFrame): A pandas DataFrame containing 'shoreline_position', 'time',
                                         'tr_name', and 'obs_count', where 'time' must be a datetime
                                         column and 'tr_name' is the identifier for each group.
        obs_count_threshold (int): A transect is fitted only if its 'obs_count' exceeds this value.
                                   Default is 5, the value that was previously hard-coded.

    Returns:
        DataFrame: A DataFrame with columns 'tr_name' for the transect names, 'intercept' and 'rate' for the
                   OLS regression coefficients, and 'residues' for the sum of squared residuals of each
                   regression (NaN when the solver does not report a single residual value).
    """
    groups = shoreline_positions.groupby("tr_name")
    records = []

    for name, group in tqdm(groups, total=groups.ngroups):
        if group.obs_count.iloc[0] <= obs_count_threshold:
            continue

        p, res = ols_AC(group)

        # scipy.linalg.lstsq hands back an empty array instead of a scalar when it
        # does not compute the residual; store NaN so the column stays numeric.
        residual = float(res) if np.size(res) == 1 else float("nan")

        records.append(
            {
                "tr_name": name,
                "intercept": p[0],
                "rate": p[1],
                "residues": residual,
            }
        )

    # Passing `columns` keeps the schema stable even when no transect qualifies.
    return pd.DataFrame(
        records, columns=["tr_name", "intercept", "rate", "residues"]
    )


# Fit a linear rate per transect, then attach each transect's geometry from
# the GCTS query result (left join, so every fitted transect is kept).
ambient_change = all_AC(df_clean)
ambient_change = ambient_change.merge(
    transects[["tr_name", "geometry"]],
    on="tr_name",
    how="left",
)
# Wrap as a GeoDataFrame with CRS EPSG:4326 and keep only the columns used for plotting.
ambient_change = gpd.GeoDataFrame(ambient_change, crs=4326)[
    ["rate", "geometry", "tr_name"]
]

In [None]:
# Line plot of the fitted rate.
# NOTE(review): no x column is passed (the alongshore-distance option is
# commented out), so the x-axis label may not describe what is plotted.
# `plot` is assigned but not displayed in this cell.
plot = ambient_change.hvplot(
    kind="line",
    # x="alongshore_dist_km",
    y=["rate"],
    # groupby=["coastline"],
    xlabel="Alongshore Distance [km]",
    color="blue",
    alpha=0.2,
)

In [None]:
import colorcet as cc
import holoviews as hv

# Render HoloViews/hvPlot output with the Bokeh backend.
hv.extension("bokeh")

# Map of the transects on ESRI tiles, coloured by the fitted rate on a linear
# colour scale; the CET_D3 colour list is reversed.
ambient_change.hvplot(
    geo=True,
    tiles="ESRI",
    color="rate",
    line_width=3,
    title="Transects Colored by Rate",
    width=800,
    colorbar=True,
    cnorm="linear",
    cmap=cc.CET_D3[::-1],
)

In [None]:
# NOTE(review): bare empty dict literal; looks like leftover scratch — candidate for removal.
{}

In [None]:
# NOTE(review): called without arguments; GeoSeries.from_xy expects x and y
# coordinate inputs, so this presumably raises a TypeError as written —
# looks like leftover scratch, confirm and remove or complete.
gpd.GeoSeries.from_xy()