In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import welch
from scipy.stats import chi2


# ----------------------------
# Linear wave theory helpers
# ----------------------------
def wavenumber_linear_dispersion(f_hz, h_m, g=9.81, niter=30):
    """
    Solve the linear dispersion relation  w^2 = g k tanh(k h)  for k.

    Newton-Raphson iteration on F(k) = g k tanh(k h) - w^2, started from the
    deep-water wavenumber k0 = w^2 / g and run for a fixed number of steps.

    Parameters
    ----------
    f_hz : scalar or array-like
        Frequency in Hz. Entries that are not strictly positive (zero,
        negative, NaN) are not solved and get k = 0.
    h_m : float
        Water depth in metres (assumed > 0; not validated here).
    g : float
        Gravitational acceleration (m/s^2).
    niter : int
        Number of Newton iterations (no convergence test is performed).

    Returns
    -------
    k : ndarray
        Wavenumber in rad/m, same shape as ``f_hz``.
    """
    f_hz = np.asarray(f_hz, dtype=float)
    k = np.zeros_like(f_hz, dtype=float)

    # Only solve for strictly positive frequencies
    pos = f_hz > 0
    if not np.any(pos):
        return k

    w2 = (2 * np.pi * f_hz[pos]) ** 2

    # deep-water initial guess
    kp = w2 / g

    for _ in range(niter):
        tanh_kh = np.tanh(kp * h_m)
        # sech^2(x) = 1 - tanh^2(x). Using this identity instead of
        # 1 / cosh(x)**2 keeps the derivative finite and warning-free when
        # k*h is large enough for cosh to overflow (deep water / high f).
        sech2_kh = 1.0 - tanh_kh**2

        F = g * kp * tanh_kh - w2
        dF = g * tanh_kh + g * kp * h_m * sech2_kh

        kp = kp - F / dF

    k[pos] = kp
    return k



def pressure_head_to_eta_psd_multiplier(f_hz, h_m, z_m=-1.0, g=9.81, kmaxh=12.0):
    """
    Spectral multiplier M(f) that maps a pressure-head PSD to a surface-elevation PSD:

        S_etaeta(f) = S_hh(f) * M(f)

    with S_hh the PSD of pressure head h_p = p'/(rho g) in m^2/Hz.

    Linear wave theory gives the pressure response at elevation z as
        R = cosh(k (z + h)) / cosh(k h),     h_p(z) = eta * R
    so the PSD multiplier is
        M(f) = (1 / R)^2 = (cosh(k h) / cosh(k (z + h)))^2

    Parameters
    ----------
    f_hz  : frequency or array of frequencies (Hz).
    h_m   : mean water depth (m).
    z_m   : sensor elevation relative to still water level (m), negative below
            the surface. Seabed sensor: z = -h. Port slightly above the bed:
            z = -(h - z_above_bed).
    g     : gravitational acceleration (m/s^2), forwarded to the dispersion solver.
    kmaxh : upper limit applied to k*h so the high-frequency amplification
            stays bounded.

    Returns
    -------
    (M, kh_c) : the multiplier and the capped k*h values it was evaluated with.
    """
    freqs = np.asarray(f_hz, dtype=float)

    # Dimensionless depth k*h from the dispersion relation, then clipped at kmaxh.
    kh_raw = wavenumber_linear_dispersion(freqs, h_m, g=g) * h_m
    kh_c = np.minimum(kh_raw, kmaxh)

    # The capped wavenumber (kh_c / h) is also used for the sensor-height term,
    # so numerator and denominator of R stay consistent with each other.
    height_above_bed = z_m + h_m
    k_height = (kh_c / h_m) * height_above_bed

    response = np.cosh(k_height) / np.cosh(kh_c)
    M = (1.0 / response) ** 2

    return M, kh_c


def add_period_axis(ax, period_ticks_s=(0.5, 1, 2, 5, 10, 20, 50)):
    """
    Attach a top x-axis labelled in period (s) to a frequency plot.

    A twin axis is created on ``ax``, forced to log scale and given the same
    x-limits as ``ax``. Ticks are placed at f = 1/T for every requested period
    T whose frequency lies inside those limits, and labelled with T.

    NOTE(review): the limits are copied once, at call time; presumably ``ax``
    is already log-scaled in Hz and its limits are final -- confirm at call sites.

    Returns the new top axis.
    """
    periods = np.asarray(period_ticks_s, dtype=float)
    freqs = 1.0 / periods

    top = ax.twiny()
    top.set_xscale("log")

    lo, hi = ax.get_xlim()
    top.set_xlim(lo, hi)

    # Keep only the periods whose frequency falls inside the visible range.
    visible = np.logical_and(freqs >= lo, freqs <= hi)
    labels = [format(period, "g") for period in periods[visible]]

    top.set_xticks(freqs[visible])
    top.set_xticklabels(labels)
    top.set_xlabel("Period (s)")
    return top


def bulk_stats_from_spectrum(f, S, fmin=None, fmax=None):
    """
    Bulk wave statistics from a one-sided variance spectrum.

    Only samples with finite f and S and f > 0 are used, optionally restricted
    to fmin <= f <= fmax. Moments are integrated with the trapezoidal rule
    (f is assumed to be in ascending order; it is not sorted here).

    Parameters
    ----------
    f : array-like
        Frequencies (Hz).
    S : array-like
        Spectral density at ``f`` (same length as ``f``).
    fmin, fmax : float or None
        Optional inclusive integration band.

    Returns
    -------
    dict with keys
        m0  : zeroth moment, integral of S df
        m1  : first moment, integral of f*S df
        Hm0 : 4*sqrt(m0)              (NaN if m0 <= 0)
        fp  : frequency of the spectral maximum
        Tp  : 1/fp
        Tm0 : m0/m1                   (NaN if m1 <= 0)
    All values are NaN when fewer than 3 usable samples remain.
    """
    f = np.asarray(f, dtype=float)
    S = np.asarray(S, dtype=float)

    mask = np.isfinite(f) & np.isfinite(S) & (f > 0)
    if fmin is not None:
        mask &= (f >= fmin)
    if fmax is not None:
        mask &= (f <= fmax)

    ff = f[mask]
    SS = S[mask]
    if ff.size < 3:
        return dict(m0=np.nan, Hm0=np.nan, fp=np.nan, Tp=np.nan, m1=np.nan, Tm0=np.nan)

    # np.trapezoid only exists in NumPy >= 2.0; older releases provide the same
    # routine as np.trapz. Look the new name up first so no deprecation warning
    # is triggered on NumPy 2.x.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz

    m0 = integrate(SS, ff)
    m1 = integrate(ff * SS, ff)

    Hm0 = 4.0 * np.sqrt(m0) if m0 > 0 else np.nan
    fp = ff[np.argmax(SS)]
    Tp = 1.0 / fp if fp > 0 else np.nan
    Tm0 = (m0 / m1) if (m1 > 0) else np.nan

    return dict(m0=m0, Hm0=Hm0, fp=fp, Tp=Tp, m1=m1, Tm0=Tm0)


# ----------------------------
# Clean LWT-only analyzer
# ----------------------------
def _lwt_pressure_to_pa_factor(pressure_units):
    """Return the factor converting `pressure_units` ('dbar' or 'Pa') to pascals."""
    units = pressure_units.lower()
    if units == "dbar":
        return 1.0e4
    if units == "pa":
        return 1.0
    raise ValueError("pressure_units must be 'dbar' or 'Pa'")


def _lwt_site_spectrum(df, pressure_col, depth_col, to_pa, rho, g, fs,
                       nperseg, noverlap, sensor_height_above_bed_m, kmaxh):
    """Compute one site's LWT-corrected elevation spectrum, 95% band and correction curve."""
    p = df[pressure_col].to_numpy(dtype=float)
    # NOTE: non-finite samples are simply dropped, so any gaps in the record
    # are closed up before the spectral estimate.
    p = p[np.isfinite(p)]

    # Remove the mean (atmospheric offset and static head do not matter for
    # waves), convert to Pa, then to pressure head in metres.
    head = (p - np.mean(p)) * to_pa / (rho * g)

    # Welch on pressure head to get S_hh (m^2/Hz)
    f, S_hh = welch(
        head,
        fs=fs,
        window="hann",
        nperseg=nperseg,
        noverlap=noverlap,
        detrend="constant",
        scaling="density",
        return_onesided=True
    )

    # Segment count and degrees of freedom. dof = 2*K treats the overlapping
    # segments as independent.
    step = nperseg - noverlap
    K = int(max(1 + (len(head) - nperseg) // step, 1))
    dof = 2 * K

    # Chi-squared 95% band on S_hh; it carries over to S_eta through M.
    alpha = 0.05
    Slo_hh = (dof * S_hh) / chi2.ppf(1 - alpha / 2, dof)
    Shi_hh = (dof * S_hh) / chi2.ppf(alpha / 2, dof)

    # Mean depth and sensor elevation z (relative to surface; negative below)
    hbar = float(np.nanmean(df[depth_col].to_numpy(dtype=float)))
    z = -(hbar - sensor_height_above_bed_m)

    # LWT multiplier: S_etaeta = S_hh * M(f)
    M, kh_c = pressure_head_to_eta_psd_multiplier(f, hbar, z_m=z, g=g, kmaxh=kmaxh)

    return dict(
        f=f, S=S_hh * M, Slo=Slo_hh * M, Shi=Shi_hh * M,
        dof=dof, K=K,
        hbar=hbar, z=z,
        M=M, kh=kh_c
    )


def _lwt_plot_eta_spectra(spec, color_map, experiment_name, fmin, fmax,
                          seg_seconds, overlap_frac, kmaxh, show_ci):
    """Overlay every site's elevation spectrum (and optional CI band) on one semilog-x figure."""
    fig, ax = plt.subplots(figsize=(9, 6))

    for site, out in spec.items():
        f = out["f"]
        band = (f >= fmin) & (f <= fmax)

        (line,) = ax.semilogx(f[band], out["S"][band], color=color_map.get(site), label=site)

        if show_ci:
            # Reuse the colour actually given to the line so the band matches
            # it even when the site has no entry in color_map.
            ax.fill_between(
                f[band],
                out["Slo"][band],
                out["Shi"][band],
                alpha=0.25,
                linewidth=0,
                color=line.get_color(),
            )

    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel(r"Variance spectrum $S_{\eta\eta}(f)$ (m$^2$/Hz)")
    ax.set_title(
        f"{experiment_name}: LWT-corrected surface elevation spectra from Pressure\n"
        f"{seg_seconds:.0f}-s segments, {overlap_frac*100:.0f}% overlap, kmaxh={kmaxh:g}"
    )
    ax.grid(True, which="both")
    ax.legend()
    add_period_axis(ax, period_ticks_s=(0.5, 1, 2, 5, 10, 20, 50))

    fig.tight_layout()
    plt.show()


def _lwt_plot_correction(spec, color_map, experiment_name, fmin, fmax):
    """Plot each site's PSD multiplier M(f) = (cosh(kh)/cosh(k(z+h)))^2 against frequency."""
    fig, ax = plt.subplots(figsize=(9, 5))

    for site, out in spec.items():
        f = out["f"]
        band = (f >= fmin) & (f <= fmax)

        ax.semilogx(
            f[band],
            out["M"][band],
            color=color_map.get(site),
            label=f"{site} (h={out['hbar']:.2f} m)"
        )

    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel(r"LWT correction multiplier $M(f)$ (unitless)")
    ax.set_title(f"{experiment_name}: Linear wave theory correction vs frequency")
    ax.grid(True, which="both")
    ax.legend()
    add_period_axis(ax, period_ticks_s=(0.5, 1, 2, 5, 10, 20, 50))

    fig.tight_layout()
    plt.show()


def analyze_experiment_lwt_pressure(
    site_dfs: dict,
    color_map: dict,
    experiment_name: str,
    fs: float = 8.0,
    seg_seconds: float = 120.0,
    overlap_frac: float = 0.5,
    fmin: float = 0.02,
    fmax: float = 2.5,
    pressure_col: str = "Pressure",
    depth_col: str = "Depth",
    pressure_units: str = "dbar",   # "dbar" or "Pa"
    rho: float = 1025.0,
    g: float = 9.81,
    sensor_height_above_bed_m: float = 0.0,  # 0 => seabed; use e.g. 0.05 if port is 5 cm above bed
    kmaxh: float = 12.0,  # cap on k*h; author's note: this default is probably too much
    show_ci: bool = True,
    plot_correction: bool = True,
):
    """
    LWT-only processing of pressure records, one spectrum per site.

    For each site:
      1. De-mean the pressure and convert it to pressure head h_p = p'/(rho g).
      2. Estimate the pressure-head PSD S_hh(f) with Welch's method
         (Hann window, `seg_seconds` segments, `overlap_frac` overlap).
      3. Multiply by the linear-wave-theory multiplier M(f) to obtain the
         surface-elevation PSD S_eta(f) = S_hh(f) * M(f). The correction is
         applied in the frequency domain; no eta(t) time series is built.
         M uses the site's mean depth and a capped k*h (`kmaxh`) to avoid
         high-frequency noise blow-up.
      4. Compute bulk statistics of S_eta over [fmin, fmax].

    Two figures are shown: the overlaid elevation spectra (with an optional
    95% chi-squared band) and, if `plot_correction`, the multiplier M(f).

    Parameters
    ----------
    site_dfs : dict mapping site name -> DataFrame holding `pressure_col`
        and `depth_col`.
    color_map : dict mapping site name -> matplotlib colour; sites without an
        entry get matplotlib's automatic colour.
    experiment_name : label used in plot titles and the stats table.
    fs : sampling frequency (Hz).
    seg_seconds, overlap_frac : Welch segment length (s) and fractional overlap.
    fmin, fmax : frequency band (Hz) used for bulk statistics and plotting.
    pressure_units : "dbar" or "Pa" (case-insensitive).
    rho, g : water density (kg/m^3) and gravity (m/s^2).
    sensor_height_above_bed_m : height of the pressure port above the bed (m).
    kmaxh : cap on k*h in the LWT multiplier.
    show_ci : draw the confidence band on the spectra plot.
    plot_correction : also draw the M(f) figure.

    Returns
    -------
    spec : dict keyed by site; each value is a dict with keys
        f, S (elevation PSD), Slo, Shi (confidence bounds), dof, K,
        hbar, z, M, kh.
    stats_df : DataFrame with one row of bulk statistics per site.

    Raises
    ------
    ValueError
        If `pressure_units` is neither "dbar" nor "Pa".
    """
    nperseg = int(seg_seconds * fs)
    noverlap = int(overlap_frac * nperseg)

    # Validate the units once, up front, instead of once per site.
    to_pa = _lwt_pressure_to_pa_factor(pressure_units)

    spec = {}
    stats_rows = []

    # ---- First pass: compute spectra and store correction curves ----
    for site, df in site_dfs.items():
        out = _lwt_site_spectrum(
            df, pressure_col, depth_col, to_pa, rho, g, fs,
            nperseg, noverlap, sensor_height_above_bed_m, kmaxh,
        )
        spec[site] = out

        # Bulk stats on eta spectrum
        stats = bulk_stats_from_spectrum(out["f"], out["S"], fmin=fmin, fmax=fmax)

        stats_rows.append({
            "experiment": experiment_name,
            "site": site,
            "h_mean (m)": out["hbar"],
            "Hm0 (m)": stats["Hm0"],
            "Tp (s)": stats["Tp"],
            "Tm0 (s)": stats["Tm0"],
            "segments_K": out["K"],
            "dof": out["dof"]
        })

    stats_df = pd.DataFrame(stats_rows)

    # ---- Plot 1: overlay eta spectra ----
    _lwt_plot_eta_spectra(
        spec, color_map, experiment_name, fmin, fmax,
        seg_seconds, overlap_frac, kmaxh, show_ci,
    )

    # ---- Plot 2: depth-dependent correction vs frequency ----
    if plot_correction:
        _lwt_plot_correction(spec, color_map, experiment_name, fmin, fmax)

    return spec, stats_df


In [None]:
# Load the processed sensor records; the first CSV column becomes the index
# and is parsed as dates.
data_dir = "C:/crs/proj/FA_science/"
csv_kwargs = {"index_col": 0, "parse_dates": True}

# Experiment 1 records
df_blue1r = pd.read_csv(f"{data_dir}blue1r.csv", **csv_kwargs)
df_yellow1r = pd.read_csv(f"{data_dir}yellow1r.csv", **csv_kwargs)
df_red1r = pd.read_csv(f"{data_dir}red1r.csv", **csv_kwargs)
df_green1r = pd.read_csv(f"{data_dir}green1r.csv", **csv_kwargs)

# Experiment 2 records
df_blue2r = pd.read_csv(f"{data_dir}blue2r.csv", **csv_kwargs)
df_yellow2r = pd.read_csv(f"{data_dir}yellow2r.csv", **csv_kwargs)
df_red2r = pd.read_csv(f"{data_dir}red2r.csv", **csv_kwargs)
df_green2r = pd.read_csv(f"{data_dir}green2r.csv", **csv_kwargs)

# Settings shared by both experiments
lwt_kwargs = dict(
    pressure_units="dbar",           # change to "Pa" if needed
    sensor_height_above_bed_m=0.03,  # set if your port is above bed
    kmaxh=6.0,
)

# ---- Experiment 1 ----
exp1_sites = dict(Offshore=df_blue1r, Rocks=df_red1r, Sand=df_yellow1r, Grass=df_green1r)
color_map1 = dict(Offshore="b", Rocks="r", Sand="gold", Grass="g")  # blue, red, yellow, green

spec1p, stats1p = analyze_experiment_lwt_pressure(
    exp1_sites,
    color_map=color_map1,
    experiment_name="Experiment 1",
    fmax=2.5,
    **lwt_kwargs,
)

# ---- Experiment 2 ----
exp2_sites = dict(Offshore=df_blue2r, Mid=df_red2r, Shallow=df_yellow2r, Grass=df_green2r)
color_map2 = dict(Offshore="b", Mid="r", Shallow="gold", Grass="g")  # blue, red, yellow, green

spec2p, stats2p = analyze_experiment_lwt_pressure(
    exp2_sites,
    color_map=color_map2,
    experiment_name="Experiment 2",
    **lwt_kwargs,
)
stats2p
