In [None]:
from importlib import reload

%load_ext autoreload
%autoreload 2

In [None]:
import ipywidgets
from glob import glob
from os import walk
import os
import numpy as np

from scipy.ndimage import gaussian_filter1d

from atl_module.io.atl03_netcdf_loading import get_beams, load_beam_array_ncds
from atl_module.bathymetry_extraction import icesat_bathymetry
from atl_module.geospatial_utils.raster_interaction import query_raster
from atl_module.error_calc import icesat_error_rms_mae

from bokeh.io import output_notebook
from bokeh.palettes import Spectral5
from bokeh.plotting import figure, show
from bokeh.transform import factor_cmap

output_notebook()

# Interactive tools attached to every Bokeh figure in this notebook. The empty final
# entry makes the joined string end with a trailing comma.
_BOKEH_TOOL_NAMES = (
    "hover",
    "crosshair",
    "pan",
    "wheel_zoom",
    "zoom_in",
    "zoom_out",
    "box_zoom",
    "undo",
    "redo",
    "reset",
    "tap",
    "save",
    "box_select",
    "poly_select",
    "lasso_select",
    "",
)
TOOLS = ",".join(_BOKEH_TOOL_NAMES)

# TODO abstract the graph into smaller functions, then combine into one "evaluate transect" function that can be run as callback with the widgets

In [None]:
# Work from the parent of the directory the kernel started in; the data paths below
# ("../data/...") are relative to it. The start directory is recorded on the first run
# only, so re-running this cell is idempotent instead of climbing one more level each
# time (which is what a bare `%cd ..` does).
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.abspath(os.path.join(_NOTEBOOK_START_DIR, os.pardir)))
print(os.getcwd())

In [None]:
# Module-level state shared by the widget callbacks further down. The callbacks rebind
# these names with `global`, and show_transect() reads the chosen granule and beam, so
# every one of them is created here with an empty placeholder before any widget fires.
site_granules = []
chosen_granule = ""
beam = ""
granchooser = ipywidgets.Select()
beamchooser = ipywidgets.Select()

In [None]:
# Each immediate subdirectory of the test-site folder is one selectable site.
# Only the top level is needed, so take just the first step of os.walk instead of
# materialising the whole tree. os.walk lists directories in arbitrary filesystem
# order, so sort them for a stable chooser.
TEST_SITES_DIR = "../data/test_sites"
subdirlist = sorted(next(walk(TEST_SITES_DIR), (TEST_SITES_DIR, [], []))[1])
if not subdirlist:
    print(f"No test sites found under {os.path.abspath(TEST_SITES_DIR)}")
sitechooser = ipywidgets.Select(options=subdirlist, description="Choose a test site")

In [None]:
def spgaussfilter(arrayin, sigma=20):
    """Return the centre sample of a Gaussian-smoothed 1-D window.

    Meant as a ``rolling(...).apply(spgaussfilter, raw=True)`` reducer: the whole
    window is smoothed with ``scipy.ndimage.gaussian_filter1d`` and only the value at
    the middle index (``len // 2``) is kept.

    Parameters
    ----------
    arrayin : array_like
        One 1-D window of samples. It is converted to float first; otherwise
        ``gaussian_filter1d`` would produce (and truncate to) the input's integer dtype.
    sigma : float, default 20
        Standard deviation of the Gaussian kernel, in samples.

    Returns
    -------
    numpy.float64
        The smoothed value at the window's centre index.

    Raises
    ------
    ValueError
        If ``arrayin`` is empty, because there is no centre sample.
    """
    window = np.asarray(arrayin, dtype=float)
    if window.size == 0:
        raise ValueError("spgaussfilter needs a non-empty window")
    smoothed = gaussian_filter1d(window, sigma=sigma)
    return smoothed[len(smoothed) // 2]

In [None]:
def _transect_figure(title):
    """Return a full-width, 200 px tall Bokeh figure with the notebook's TOOLS."""
    return figure(
        tools=TOOLS,
        sizing_mode="scale_width",
        height=200,
        title=title,
    )


def _plot_raw_photons(raw_data):
    """Show every photon (``dist_or`` vs ``Z_geoid``) coloured by ``oc_sig_conf``.

    Side effect: ``raw_data["oc_sig_conf"]`` is converted to strings in place so Bokeh
    can use it as a categorical factor for the colour map and the legend.
    """
    raw_data["oc_sig_conf"] = raw_data.oc_sig_conf.astype("str")
    # NOTE(review): Spectral5 has five colours; if a transect has more distinct
    # confidence values, check how factor_cmap colours the surplus factors.
    signal_conf_cmap = factor_cmap(
        "oc_sig_conf",
        palette=Spectral5,
        factors=sorted(raw_data.oc_sig_conf.unique().astype("str")),
    )
    raw_plot = _transect_figure("Raw Photons on transect")
    raw_plot.scatter(
        source=raw_data,
        x="dist_or",
        y="Z_geoid",
        color=signal_conf_cmap,
        legend_field="oc_sig_conf",
    )
    show(raw_plot)


def _plot_filtering(raw_data, points_after_filtering):
    """Overlay the kept photons (red) on all photons (faint).

    Also draws ``sea_level_interp`` (orange) and ``gebco_elev`` (blue) lines.
    """
    filtered_plot = _transect_figure("Points After Filtering")
    filtered_plot.scatter(source=raw_data, x="dist_or", y="Z_geoid", alpha=0.1)
    filtered_plot.scatter(
        source=points_after_filtering, x="dist_or", y="Z_geoid", color="red", alpha=0.5
    )
    filtered_plot.line(
        source=points_after_filtering, x="dist_or", y="sea_level_interp", color="orange"
    )
    filtered_plot.line(
        source=points_after_filtering, x="dist_or", y="gebco_elev", color="blue"
    )
    show(filtered_plot)


def _mask_low_kde(bathy_df, threshold_std_factor, min_kde_threshold):
    """Blank out ``z_kde`` where the KDE value is below a threshold.

    The threshold is
    ``max(median(kde_val) - threshold_std_factor * std(kde_val), min_kde_threshold)``.
    Rows with ``kde_val`` below it get ``z_kde = NaN``. Rows whose ``kde_val`` is
    itself NaN compare False and are left untouched.

    Returns
    -------
    (pandas.DataFrame, float)
        A masked copy of ``bathy_df`` and the threshold that was applied.
    """
    kde_val = bathy_df.kde_val
    thresholdval = max(
        kde_val.median() - threshold_std_factor * kde_val.std(), min_kde_threshold
    )
    # Series.mask builds a new column (NaN where the condition holds). This avoids
    # `np.NaN`, which NumPy 2.0 removed, and avoids writing through .loc into the
    # frame returned by add_rolling_kde.
    masked = bathy_df.assign(z_kde=bathy_df.z_kde.mask(kde_val < thresholdval))
    return masked, thresholdval


def _plot_kde_seafloor(bathy_df):
    """Plot the photons with the (masked) KDE seafloor line in red."""
    kde_seafloor_plot = _transect_figure("Assumed Seafloor Pre-correction")
    kde_seafloor_plot.scatter(source=bathy_df, x="dist_or", y="Z_geoid", alpha=0.5)
    kde_seafloor_plot.line(source=bathy_df, x="dist_or", y="z_kde", color="red")
    show(kde_seafloor_plot)


def _plot_truth_comparison(bathy_df):
    """Plot the KDE seafloor (``z_kde``) against the truth raster (``fema_elev``, red)."""
    truth_comp_plot = _transect_figure("Seafloor calculated from ICESat using KDE")
    truth_comp_plot.line(source=bathy_df, x="dist_or", y="z_kde")
    truth_comp_plot.line(source=bathy_df, x="dist_or", y="fema_elev", color="red")
    show(truth_comp_plot)


def _plot_error_diagnostics(bathy_df, thresholdval):
    """Draw static pandas/matplotlib scatter plots of the error against truth and KDE value.

    The last plot marks the median ``kde_val`` (red) and the applied threshold (green).
    """
    bathy_df.plot.scatter(
        x="fema_elev",
        y="error",
        xlim=[0, -12],
        figsize=(15, 15),
        c="kde_val",
        cmap="viridis",
        title="KDE error vs truth elevation",
    )
    bathy_df.plot.scatter(
        x="fema_elev",
        y="z_kde",
        xlim=[0, -12],
        ylim=[0, -12],
        figsize=(15, 15),
        c="kde_val",
        cmap="viridis",
        title="KDE seafloor vs truth elevation",
    )
    ax = bathy_df.plot.scatter(x="error", y="kde_val", title="KDE value vs error")
    ax.axhline(bathy_df.kde_val.median(), color="red")
    ax.axhline(thresholdval, color="green")


def show_transect(
    granule=None,
    beam_name=None,
    truth_raster=None,
    threshold_std_factor=0,
    min_kde_threshold=0.1,
):
    """Run the KDE bathymetry pipeline on one beam and plot every stage.

    Stages: load the beam, add along-track distance, filter photons, compute a rolling
    KDE seafloor, mask low-KDE samples, then compare against an in-situ truth raster.

    Parameters
    ----------
    granule : str, optional
        Path to the ATL03 netCDF granule. Defaults to the ``chosen_granule`` global
        set by the granule chooser.
    beam_name : str, optional
        Beam to load. Defaults to the ``beam`` global set by the beam chooser.
    truth_raster : str, optional
        Truth DEM passed to ``query_raster``. Defaults to
        ``<site dir>/in-situ-DEM/truth.vrt``, where the site dir is the folder that
        contains the granule's ``ATL03`` directory. This assumes every site follows
        the niiahu layout; TODO confirm.
    threshold_std_factor : float, default 0
        Number of ``kde_val`` standard deviations below the median at which
        ``z_kde`` is masked.
    min_kde_threshold : float, default 0.1
        Lower bound on that masking threshold.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``bathy_df`` with ``fema_elev`` and ``error`` columns, and ``raw_data`` (whose
        ``oc_sig_conf`` column has been converted to strings).

    Raises
    ------
    ValueError
        If no granule or beam is given and none has been selected in the widgets.
    """
    if granule is None:
        granule = globals().get("chosen_granule")
    if beam_name is None:
        beam_name = globals().get("beam")
    if not granule:
        raise ValueError(
            "No granule selected: pick one in the granule chooser or pass granule=..."
        )
    if not beam_name:
        raise ValueError(
            "No beam selected: pick one in the beam chooser or pass beam_name=..."
        )
    if truth_raster is None:
        # Granules are globbed from <site>/ATL03/*.nc, so the site folder is two
        # levels above the granule file. Previously this was hard-coded to niiahu.
        site_dir = os.path.dirname(os.path.dirname(granule))
        truth_raster = os.path.join(site_dir, "in-situ-DEM", "truth.vrt")

    beamdata = load_beam_array_ncds(granule, beam_name)
    print(beamdata.dtype.metadata)
    # convert the numpy array into a dataframe
    raw_data = icesat_bathymetry.add_along_track_dist(beamdata)
    # Filter before _plot_raw_photons converts oc_sig_conf to strings, in case
    # _filter_points relies on its numeric dtype.
    points_after_filtering = icesat_bathymetry._filter_points(
        raw_data,
        low_limit=-30,
        high_limit=2,
        rolling_window=200,
        max_sea_surf_elev=2,
        filter_below_z=-60,
        filter_below_depth=-60,
        n=1,
        max_geoid_high_z=5,
    )
    _plot_raw_photons(raw_data)
    _plot_filtering(raw_data, points_after_filtering)

    bathy_df = icesat_bathymetry.add_rolling_kde(
        points_after_filtering, window=100, window_meters=None, min_photons=None
    )
    # Alternative smoother that was tried: a rolling(window=200, center=True)
    # .apply(spgaussfilter, raw=True) over the photon elevations.
    bathy_df, thresholdval = _mask_low_kde(
        bathy_df, threshold_std_factor, min_kde_threshold
    )
    print(thresholdval)
    _plot_kde_seafloor(bathy_df)

    true_bathy = query_raster(bathy_df, src=truth_raster)
    bathy_df = bathy_df.assign(fema_elev=true_bathy, error=bathy_df.z_kde - true_bathy)
    _plot_truth_comparison(bathy_df)
    _plot_error_diagnostics(bathy_df, thresholdval)
    return bathy_df, raw_data

In [None]:
# These callbacks chain into one another: choosing a site lists its granules,
# choosing a granule lists its beams, and choosing a beam stores it in `beam`.
def list_netcdf_granules(foldername):
    """Widget callback: list a site's ATL03 netCDF granules and offer a chooser.

    Rebinds the ``site_granules`` and ``granchooser`` globals. The chooser is wired
    to ``set_granule_file`` only when the site actually has granules.
    """
    global site_granules, granchooser
    # glob returns files in arbitrary order; sort so the chooser is stable.
    site_granules = sorted(glob(f"../data/test_sites/{foldername}/ATL03/*.nc"))
    granchooser = ipywidgets.Select(options=site_granules)
    print(foldername, "has been selected")
    if not site_granules:
        # An empty chooser has no file to hand to set_granule_file, so don't wire it.
        print(f"No ATL03 .nc granules found for {foldername}")
        return None
    ipywidgets.interact(set_granule_file, file=granchooser)
    return None


def set_granule_file(file):
    """Widget callback: record the chosen granule and offer a chooser for its beams.

    Rebinds the ``chosen_granule`` and ``beamchooser`` globals. The beam chooser is
    wired to ``set_chosen_beam``. When ``file`` is empty (e.g. None from an empty
    granule chooser), only ``chosen_granule`` is updated and nothing else happens.
    """
    global chosen_granule, beamchooser
    # Record even an empty selection so a stale granule is not reused later.
    chosen_granule = file
    if not file:
        # There is no file to pass to get_beams.
        print("No granule selected")
        return None
    beamlist = get_beams(chosen_granule)
    beamchooser = ipywidgets.Select(options=beamlist)
    print(f"Granule {chosen_granule} selected")
    ipywidgets.interact(set_chosen_beam, beamname=beamchooser)
    return None


def set_chosen_beam(beamname):
    """Widget callback: store the selected beam name in the module-level ``beam``."""
    globals()["beam"] = beamname

In [None]:
# Site chooser -> granule chooser -> beam chooser. interact() returns the wrapped
# function; the trailing semicolon keeps its repr out of the cell output so only the
# widgets are shown.
ipywidgets.interact(list_netcdf_granules, foldername=sitechooser);

In [None]:
# Run the whole pipeline on the granule/beam picked above. Keep the bathymetry frame
# and the raw photon frame for the error analysis that follows.
(df_out, raw_df) = show_transect()

In [None]:
# Correlation of every numeric column with the absolute KDE error, sorted. Only
# numeric/bool columns are kept: since pandas 2.0, DataFrame.corr() raises on
# non-numeric columns instead of silently dropping them.
(
    df_out.assign(errorabs=df_out.error.abs())
    .select_dtypes(include=["number", "bool"])
    .corr()["errorabs"]
    .sort_values()
)

In [None]:
# RMSE and then MAE of the KDE seafloor against the truth raster (pandas skips NaNs).
kde_error = df_out.error
print((kde_error ** 2).mean() ** 0.5)
print(kde_error.abs().mean())

In [None]:
# Same RMSE / MAE, but for the KDE surface after subtracting sea_level_interp,
# dac_corr and tide_ocean_corr from z_kde.
dac_tide_msl = (
    df_out.z_kde
    - df_out.sea_level_interp
    - df_out.dac_corr
    - df_out.tide_ocean_corr
)
corrected_error = dac_tide_msl - df_out.fema_elev
print((corrected_error ** 2).mean() ** 0.5)
print(corrected_error.abs().mean())

In [None]:
# Selected pointing, correction and elevation columns of the final frame, for inspection.
inspect_columns = [
    "p_vec_az",
    "p_vec_elev",
    "easting_corr",
    "northing_corr",
    "Z_geoid",
    "Z_refr",
    "sea_level_interp",
]
df_out[inspect_columns]