In [1]:
import geoviews as gv
import geoviews.tile_sources as gvt
import xarray as xr
import holoviews as hv
import panel as pn
import hvplot.xarray
import pandas as pd
import numpy as np
from panel.widgets import MultiChoice, Select
from holoviews.operation.datashader import datashade, rasterize
from holoviews.util.transform import easting_northing_to_lon_lat as to_lon_lat
from bokeh.models import HoverTool
from plotly.graph_objects import Figure, Legend
from math import sqrt
import param
from typing import List

# Initialise Panel for this notebook: load the "plotly" extension (the class
# below appends plotly Figures to a pn.Row) and make "stretch_width" the
# default sizing mode for components.
pn.extension(
    "plotly",
    sizing_mode="stretch_width",
)

In [2]:
def closest_station(station_ids, latitudes, longitudes, lat, lon):
    """Return the ID of the station closest to the point ``(lat, lon)``.

    The distance is the square root of the summed squared differences of the
    latitude and longitude values (plain Euclidean distance on the raw
    coordinates, no great-circle correction).

    Args:
        station_ids: Sequence of station identifiers.
        latitudes: Sequence of station latitudes, indexed like ``station_ids``.
        longitudes: Sequence of station longitudes, indexed like ``station_ids``.
        lat: Latitude of the query point.
        lon: Longitude of the query point.

    Returns:
        The identifier of the nearest station (the earliest one on ties), or
        ``None`` when ``station_ids`` is empty. Stations whose distance is NaN
        (missing coordinates) are treated as infinitely far away, so they are
        only returned when no station has a usable distance.
    """
    min_distance = None
    closest_station_id = None
    for i, station_id in enumerate(station_ids):
        distance = sqrt((lat - latitudes[i]) ** 2 + (lon - longitudes[i]) ** 2)
        # NaN never compares as "smaller". Without this guard a leading
        # station with NaN coordinates would be stored as the minimum and no
        # later station could ever replace it.
        if np.isnan(distance):
            distance = float("inf")
        if min_distance is None or distance < min_distance:
            min_distance = distance
            closest_station_id = station_id
    return closest_station_id

In [3]:
class SnowExplorer(param.Parameterized):
    def get_cmap(self, score_selected: str):
        """Return the colormap name used to draw ``score_selected``.

        ``"pbias"`` and ``"bias"`` get ``"RdBu"``; every other score gets
        ``"YlGnBu"``.
        """
        signed_scores = ("pbias", "bias")
        return "RdBu" if score_selected in signed_scores else "YlGnBu"

    def get_clims(
        self, var_selected: str, exp_selected: List[str], score_selected: str
    ):
        """Return the ``(low, high)`` colour limits for a single-score map.

        Args:
            var_selected: Value used to select along ``variable`` in
                ``self.score_ds``.
            exp_selected: Value(s) used to select along ``exp``.
            score_selected: Name of the score read from the selection.

        Returns:
            * ``(-p95, p95)`` for ``"pbias"``/``"bias"``, where ``p95`` is the
              0.95 quantile of the absolute score values;
            * ``(0, p95)`` for ``"rmse"``, ``"stde"``, ``"nrmse"``,
              ``"moyobs"``, ``"moymod"``, where ``p95`` is the 0.95 quantile of
              the score values;
            * ``(0, 1.0)`` for ``"cor"``, ``"nse"``, ``"kge12"``, ``"kge09"``;
            * ``None`` for any other score name.
        """

        def p95(transform=None):
            # The dataset is only touched by the branches that need a quantile.
            scores = self.score_ds.sel(exp=exp_selected, variable=var_selected)[
                score_selected
            ]
            if transform is not None:
                scores = transform(scores)
            return scores.quantile(0.95).values.item()

        if score_selected in ("pbias", "bias"):
            bound = p95(np.fabs)
            return (-bound, bound)
        if score_selected in ("rmse", "stde", "nrmse", "moyobs", "moymod"):
            return (0, p95())
        if score_selected in ("cor", "nse", "kge12", "kge09"):
            return (0, 1.0)
        return None

    def get_vv_pair(self, val: xr.DataArray) -> tuple[float]:
        vv = np.min([(np.fabs(val).quantile(0.95)).data.item(), 100.0])
        return (-vv, vv)

    def get_rel(self, var_selected: str, exp_selected: List[str], score_selected: str):
        """Compare a score between ``exp_selected[1]`` and ``exp_selected[0]``.

        With ``s0``/``s1`` the score of ``exp_selected[0]``/``exp_selected[1]``
        for ``var_selected``:

        * ``"bias"``: ``-(|s1| - |s0|) / |rmse0| * 100`` where ``rmse0`` is the
          ``rmse`` of ``exp_selected[0]``;
        * ``"pbias"``: ``-(|s1| - |s0|)``;
        * ``"rmse"``, ``"stde"``, ``"nrmse"``, ``"cor"``, ``"nse"``,
          ``"kge12"``, ``"kge09"``: ``-((s1 - s0) / s0) * 100``.

        Returns:
            The computed comparison, or ``None`` for any other score name (in
            which case ``exp_selected`` is not indexed at all).
        """

        def selection(position):
            # Index lazily so that unknown scores never touch exp_selected.
            return self.score_ds.sel(
                exp=exp_selected[position], variable=var_selected
            )

        def score(position):
            return selection(position)[score_selected]

        if score_selected == "bias":
            magnitude_change = np.fabs(score(1)) - np.fabs(score(0))
            reference_rmse = np.fabs(selection(0).rmse)
            return (-magnitude_change / reference_rmse) * 100.0

        if score_selected == "pbias":
            return -(np.fabs(score(1)) - np.fabs(score(0)))

        relative_scores = ("rmse", "stde", "nrmse", "cor", "nse", "kge12", "kge09")
        if score_selected in relative_scores:
            change = score(1) - score(0)
            return -(change / score(0)) * 100.0

        return None

    def get_clims_2(
        self, var_selected: str, exp_selected: List[str], score_selected: str
    ):
        """Return the colour limits for a two-experiment comparison map.

        Args:
            var_selected: Variable passed through to ``get_rel``.
            exp_selected: The two experiments passed through to ``get_rel``.
            score_selected: Name of the score being compared.

        Returns:
            For the scores handled by ``get_rel``, the symmetric pair produced
            by ``get_vv_pair`` on the ``get_rel`` result; ``(-1.0, 1.0)`` for
            any other score name.
        """
        if score_selected in [
            "bias",
            "pbias",
            "rmse",
            "stde",
            "nrmse",
            "cor",
            "nse",
            "kge12",
            "kge09",
        ]:
            rel = self.get_rel(var_selected, exp_selected, score_selected)
            return self.get_vv_pair(rel)
        # A tuple literal is required here: tuple(-1.0, 1.0) raises TypeError
        # because tuple() accepts a single iterable argument.
        return (-1.0, 1.0)

    def get_map_pane(
        self, var_selected: str, exp_selected: List[str], score_selected: str
    ):
        match len(exp_selected):
            case 1:
                df_temp = (
                    self.score_ds.isel(time=0)
                    .sel(variable=var_selected, exp=exp_selected[0])
                    .to_pandas()
                    .reset_index()
                )
                df_temp["marker"] = np.where(df_temp.type_mes < 2.0, "square", "circle")
                df_temp = df_temp[df_temp[score_selected].notnull()]
                return gv.tile_sources.CartoDark() * gv.Points(
                    df_temp,
                    ["lon", "lat"],
                    [score_selected, "station_id", "station_name", "marker"],
                ).options(
                    title=f"{var_selected} - {exp_selected} - {score_selected}",
                    xlabel=None,
                    height=500,
                    width=800,
                    color=hv.dim(score_selected),
                    size=10,
                    marker=hv.dim("marker"),
                    colorbar=True,
                    cmap=self.get_cmap(score_selected),
                    clim=self.get_clims(var_selected, exp_selected, score_selected),
                    clipping_colors={"min": "red"},
                    tools=[
                        HoverTool(
                            tooltips=[
                                ("Station ID", "@station_id"),
                                ("Station Name", "@station_name"),
                                (f"{score_selected}", f"@{score_selected}"),
                            ]
                        )
                    ],
                )
            case 2:
                lon = self.score_ds["lon"].values
                lat = self.score_ds["lat"].values
                sta_name = self.score_ds["station_name"].values
                rmse_exp1 = self.score_ds.isel(time=0).sel(
                    variable=var_selected, exp=exp_selected[0]
                )[score_selected]
                rmse_exp2 = self.score_ds.isel(time=0).sel(
                    variable=var_selected, exp=exp_selected[1]
                )[score_selected]
                # rmse_diff = rmse_exp1 - rmse_exp2
                rmse_diff = self.get_rel(var_selected, exp_selected, score_selected)
                station_id = rmse_diff.station_id.values
                type_mes = (
                    self.score_ds.isel(time=0)
                    .sel(variable=var_selected, exp=exp_selected[0])
                    .type_mes.values
                )
                df = (
                    pd.DataFrame(
                        {
                            "lon": lon,
                            "lat": lat,
                            "station_name": sta_name,
                            "station_id": station_id,
                            "type_mes": type_mes,
                            score_selected: rmse_diff.isel(time=0).values,
                        }
                    )
                    .reset_index()
                    .dropna()
                )
                df["marker"] = np.where(df.type_mes < 2.0, "square", "circle")
                return gv.tile_sources.CartoDark() * gv.Points(
                    df,
                    ["lon", "lat"],
                    [score_selected, "station_id", "station_name", "marker"],
                ).options(
                    title=f"{var_selected} - {exp_selected} - {score_selected}",
                    xlabel=None,
                    height=500,
                    width=800,
                    color=hv.dim(score_selected),
                    size=10,
                    marker=hv.dim("marker"),
                    colorbar=True,
                    cmap="RdBu_r",
                    clim=self.get_clims_2(var_selected, exp_selected, score_selected),
                    clipping_colors={"min": "red"},
                    tools=[
                        HoverTool(
                            tooltips=[
                                ("Station ID", "@station_id"),
                                ("Station Name", "@station_name"),
                                (f"{score_selected}", f"@{score_selected}"),
                            ]
                        )
                    ],
                )
            case other:
                print("somet'rong hoss")

    def hide_index(self, plot, element):
        """Plot hook: set ``index_position`` of the plot's ``"table"`` handle to ``None``.

        ``element`` is accepted but not used.
        """
        table_handle = plot.handles["table"]
        table_handle.index_position = None

    def tap_series(self, x, y):
        """Plot the time series of the station nearest to a tapped point.

        ``x``/``y`` are converted with ``easting_northing_to_lon_lat``; the
        nearest station (via ``closest_station``) is stored in
        ``self.nearest_station``. The observations (``exp=["Obs"]``) are drawn
        as a scatter and each selected experiment (one or two) as a curve.
        The rendered plotly figure replaces the content of
        ``self.plotly_row``.

        Returns:
            A 5-pixel-high ``hv.Div`` holding a transparent rule; the actual
            figure is delivered through ``self.plotly_row``.
        """
        lon_lat = to_lon_lat(x, y)
        self.nearest_station = closest_station(
            self.LIST_STATION_IDS_PLOT,
            self.LIST_LATS_PLOT,
            self.LIST_LONS_PLOT,
            lon_lat[1],
            lon_lat[0],
        )
        variable = self.var_select.value
        experiments = self.exp_select.value

        def experiment_curve(exp_name):
            return hv.Curve(
                self.series_ds.sel(station_id=self.nearest_station, exp=[exp_name]),
                kdims=["time"],
                vdims=[variable],
                label=exp_name,
            )

        observations = hv.Scatter(
            self.series_ds.sel(station_id=self.nearest_station, exp=["Obs"]),
            kdims=["time"],
            vdims=[variable],
            label="Observation",
        ).opts(color="black", size=10)

        overlay = observations * experiment_curve(experiments[0])
        if len(experiments) == 2:
            overlay = overlay * experiment_curve(experiments[1])
            new_title = f"{self.exp_select.value[0]} and {self.exp_select.value[1]} - {self.var_select.value} for {self.nearest_station} at [{lon_lat[0]:.3f},{lon_lat[1]:.3f}]"
        else:
            new_title = f"{self.exp_select.value[0]} - {self.var_select.value} for {self.nearest_station} at [{lon_lat[0]:.3f},{lon_lat[1]:.3f}]"

        # The row is emptied before rendering, so a failed render leaves it empty.
        self.plotly_row.clear()
        rendered = hvplot.render(
            overlay.opts(
                width=800,
                tools=["hover"],
                title=new_title,
                xlabel=None,
            ),
            backend="plotly",
        )
        figure = Figure(rendered).update_layout(
            title=new_title,
            hovermode="x",
            template="plotly_dark",
            autosize=True,
            width=None,
            height=None,
            showlegend=True,
            legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        )
        self.plotly_row.append(figure)

        placeholder = hv.Div(
            """<hr style="border-top: 1px transparent; height: 1px;">"""
        )
        return placeholder.opts(height=5)

    def __init__(self):
        """Create the widgets, open both datasets and build the initial map.

        The score and series datasets are opened immediately from the default
        paths held by the two text inputs; the station id/lat/lon lists used by
        ``tap_series`` are cached from the score dataset.
        """
        # A Parameterized subclass that overrides __init__ must initialise the
        # base class, otherwise Param's per-instance state is never set up.
        super().__init__()
        # ----------------------------- Constants and Init -----------------------------------
        self.nearest_station = "BCE-1C08"
        self.score_file_input = pn.widgets.TextInput(
            name="Score NetCDF Path Input",
            value="/data/score_snw_station_diff_alti_200_HRDPS_CaPA01_CaPA02_period_20191001_20220629.nc",
            # value="/fs/site5/eccc/cmd/s/yor000/Snow_Explorer/score_snw_station_diff_alti_200_HRDPS_CaPA01_CaPA02_period_20191001_20220629.nc",
        )
        self.series_file_input = pn.widgets.TextInput(
            name="Series NetCDF Path Input",
            value="/data/series_snw_station_diff_alti_200_HRDPS_CaPA01_CaPA02_period_20191001_20220629.nc",
            # value="/fs/site5/eccc/cmd/s/yor000/Snow_Explorer/series_snw_station_diff_alti_200_HRDPS_CaPA01_CaPA02_period_20191001_20220629.nc",
        )
        self.load_button = pn.widgets.Button(name="Load")
        self.SET_EXCLUDED_VARS = {
            "lat",
            "lon",
            "elevation",
            "source",
            "station_name",
            "type_mes",
        }
        # Data variables that are not offered in the score selector.
        self.SET_EXCLUDED_VARS_PLOT = {
            "lat",
            "lon",
            "time_stamp",
            "elevation",
            "source",
            "station_name",
            "type_mes",
        }
        self.score_ds = xr.open_dataset(self.score_file_input.value)
        self.series_ds = xr.open_dataset(self.series_file_input.value)
        self.LIST_STATION_IDS_PLOT = list(self.score_ds.station_id.values)
        self.LIST_LATS_PLOT = list(self.score_ds.lat.values)
        self.LIST_LONS_PLOT = list(self.score_ds.lon.values)
        # Filled by tap_series with the plotly time-series figure.
        self.plotly_row = pn.Row()
        # ----------------------------- Widgets -----------------------------------
        self.var_select = pn.widgets.Select(
            name="Variable Select", options=list(self.score_ds.variable.values)
        )
        self.exp_select = pn.widgets.MultiChoice(
            name="Experiments Select",
            options=list(self.score_ds.exp.values),
            value=[list(self.score_ds.exp.values)[0]],
            max_items=2,
        )
        self.score_select = pn.widgets.Select(
            name="Score Select",
            options=list(set(self.score_ds.data_vars) - self.SET_EXCLUDED_VARS_PLOT),
        )
        # ----------------------------- Map and DynamicMap -----------------------------------
        self.map = self.get_map_pane(
            self.var_select.value,
            self.exp_select.value,
            self.score_select.value,
        )
        self.tap_stream = hv.streams.Tap(source=self.map, x=-122.76, y=49.555)
        self.dynamic_map = hv.DynamicMap(self.tap_series, streams=[self.tap_stream])
        # ----------------------------- Update Button -----------------------------------
        self.update_button = pn.widgets.Button(name="Update")

    @pn.depends("load_button.clicks")
    def load_view(self):
        """Reload both datasets from the path inputs and refresh the selectors.

        Declared to depend on ``load_button.clicks``. Both files are opened
        with the ``h5netcdf`` engine, the options of the variable, experiment
        and score selectors are replaced from the new score dataset, and the
        cached station id/lat/lon lists are rebuilt.

        Returns:
            A ``pn.Row`` holding a panel for the score dataset followed by a
            panel for the series dataset.
        """
        score_path = self.score_file_input.value
        self.score_ds = xr.open_dataset(score_path, engine="h5netcdf")
        series_path = self.series_file_input.value
        self.series_ds = xr.open_dataset(series_path, engine="h5netcdf")

        scores = self.score_ds
        self.var_select.options = list(scores.variable.values)
        self.exp_select.options = list(scores.exp.values)
        plottable_scores = set(scores.data_vars) - self.SET_EXCLUDED_VARS_PLOT
        self.score_select.options = list(plottable_scores)

        summary_row = pn.Row(pn.panel(self.score_ds), pn.panel(self.series_ds))

        self.LIST_STATION_IDS_PLOT = list(scores.station_id.values)
        self.LIST_LATS_PLOT = list(scores.lat.values)
        self.LIST_LONS_PLOT = list(scores.lon.values)
        return summary_row

    @pn.depends("update_button.clicks")
    def view(self):
        """Rebuild the map and its tap-driven DynamicMap from the widget values.

        Declared to depend on ``update_button.clicks``. Replaces ``self.map``,
        ``self.tap_stream`` and ``self.dynamic_map``.

        Returns:
            A ``pn.Row`` containing one ``pn.Column`` with the map above the
            dynamic map.
        """
        selection = (
            self.var_select.value,
            self.exp_select.value,
            self.score_select.value,
        )
        self.map = self.get_map_pane(*selection)

        tap = hv.streams.Tap(source=self.map, x=-122.76, y=49.555)
        self.tap_stream = tap
        self.dynamic_map = hv.DynamicMap(self.tap_series, streams=[tap])

        stacked = pn.Column(self.map, self.dynamic_map)
        return pn.Row(stacked)


# Instantiate the explorer; its constructor opens both datasets from the
# default paths and builds the initial map.
app = SnowExplorer()

# Assemble the served page: file inputs and selectors go in the sidebar; the
# main area shows the map view, the plotly time-series row and the dataset
# summaries returned by load_view. `.servable()` marks the template for
# `panel serve`.
template = pn.template.MaterialTemplate(
    logo="https://www.canada.ca/etc/designs/canada/wet-boew/assets/sig-blk-en.svg",
    site="CMC",
    title="Snow Explorer",
    sidebar=[
        pn.WidgetBox(
            "Score and Series File Inputs",
            app.score_file_input,
            app.series_file_input,
            app.load_button,
        ),
        # app._on_score_file_input,
        app.var_select,
        app.exp_select,
        app.score_select,
        app.update_button,
    ],
    main=[app.view, app.plotly_row, app.load_view],
    sidebar_width=800,
).servable()

In [5]:
# /fs/site6/eccc/mrd/rpnenv/vvi001/eval_snow/RDRS_backext/snow/fic_nc/score_snw_station_diff_alti_200_RDRS-backext-IC401_ERA5L_period_19730101_19730501.nc
# /fs/site6/eccc/mrd/rpnenv/vvi001/eval_snow/RDRS_backext/snow/fic_nc/series_snw_station_diff_alti_200_RDRS-backext-IC401_ERA5L_period_19730101_19730501.nc
# template

In [33]:
# app.get_rel("SML", ['HRDPS', 'CaPA01'], 'rmse')

In [32]:
# var_selected = "SML"
# exp_selected = ['HRDPS', 'CaPA01']
# score_selected = 'rmse'

# lon = app.score_ds["lon"].values
# lat = app.score_ds["lat"].values
# sta_name = app.score_ds["station_name"].values
# rmse_exp1 = app.score_ds.isel(time=0).sel(
#     variable=var_selected, exp=exp_selected[0]
# )[score_selected]
# rmse_exp2 = app.score_ds.isel(time=0).sel(
#     variable=var_selected, exp=exp_selected[1]
# )[score_selected]
# # rmse_diff = rmse_exp1 - rmse_exp2
# rmse_diff = app.get_rel("SML", ['HRDPS', 'CaPA01'], 'rmse')
# display(rmse_diff)
# station_id = rmse_diff.station_id.values
# type_mes = (
#     app.score_ds.isel(time=0)
#     .sel(variable=var_selected, exp=exp_selected[0])
#     .type_mes.values
# )
# df = (
#     pd.DataFrame(
#         {
#             "lon": lon,
#             "lat": lat,
#             "station_name": sta_name,
#             "station_id": station_id,
#             "type_mes": type_mes,
#             score_selected: rmse_diff.isel(time=0).values,
#         }
#     )
#     .reset_index()
#     .dropna()
# )
# df["marker"] = np.where(df.type_mes < 2.0, "square", "circle")
# df