In [None]:
# Render figures inline, and auto-reload edited modules (e.g. pyplume) before
# each cell executes so library changes are picked up without a kernel restart.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from functools import reduce
from pathlib import Path

import numpy as np
import xarray as xr

from pyplume import utils, plotting, constants
from pyplume.dataloaders import load_geo_points
from pyplume.postprocess import ParticleResult

In [None]:
# Known simulation runs: display name -> output directory. Each directory must
# contain particlefile.nc and ocean_dataset_modified.nc (both read below).
result_dirs = {
    "Kay HFR": "results/tijuana_hurrkay_2023-06-30T14-59-51/simulation_hfrnet uwls_2023-06-30T14-59-54",
    "Kay OI": "results/tijuana_hurrkay_2023-06-30T14-59-51/simulation_hfrnet oi_2023-06-30T14-59-55",
    "Henri HFR": "results/ny_hurrhenri_2023-06-30T16-22-57/simulation_hfrnet_2023-06-30T16-23-04",
    "Henri HYCOM": "results/ny_hurrhenri_2023-06-30T16-22-57/simulation_hycom_2023-06-30T16-23-04",
}
# Runs to compare (keys of result_dirs); swap with the commented line to switch storms.
to_compare = ["Henri HFR", "Henri HYCOM"]
# to_compare = ["Kay HFR", "Kay OI"]
# Optional coastline overlay as lat/lon point arrays; leave as None to skip it.
coast_lats, coast_lons = None, None
# coast_lats, coast_lons = load_geo_points("data/coastOR2Mex_tijuana.mat")
results = {}
fields = {}

# Map extent: southern/northern latitude and western/eastern longitude bounds.
# domain = constants.TIJUANA_RIVER_DOMAIN
domain = {"S": 40.2, "N": 40.7, "W": -74.1, "E": -73.2}
# domain = {
#     "S": 38.5,
#     "N": 39.2,
#     "W": -75,
#     "E": -74
# }

res_path = Path(f"results/compare_density_{'_'.join(to_compare)}")
# parents=True: do not fail with FileNotFoundError when results/ does not exist yet.
res_path.mkdir(parents=True, exist_ok=True)

# Load only the runs being compared. The other result_dirs entries may not be
# present on this machine, and the plotting cells index results/fields by
# to_compare names only.
missing_runs = [name for name in to_compare if name not in result_dirs]
assert not missing_runs, f"to_compare names missing from result_dirs: {missing_runs}"
for name in to_compare:
    resdir = Path(result_dirs[name])
    results[name] = ParticleResult(resdir / "particlefile.nc")
    fields[name] = xr.open_dataset(resdir / "ocean_dataset_modified.nc")

In [None]:
# Timestamps at which every run in to_compare has output. np.intersect1d
# returns the sorted unique values shared by all arrays. With a single run,
# its own NaN-free, sorted, de-duplicated timestamps are used instead.
if len(to_compare) > 1:
    time_arrays = [results[run_name].data_vars["time"] for run_name in to_compare]
    common_timestamps = reduce(np.intersect1d, time_arrays)
else:
    only_times = results[to_compare[0]].data_vars["time"]
    valid_times = only_times[~np.isnan(only_times)]
    common_timestamps = np.sort(np.unique(valid_times))

In [None]:
# One side-by-side frame per common timestamp: average vector field plus
# cumulative particle density for the first two runs in to_compare. Frames are
# saved to res_path as plot_000, plot_001, ...
if len(to_compare) < 2:
    raise ValueError(
        f"side-by-side frames need at least two runs in to_compare, got {to_compare!r}"
    )
# Only the first two runs are drawn (one per panel); resolve them once.
panel_names = to_compare[:2]
for frame_idx, t in enumerate(common_timestamps):
    fig, axes = plotting.carree_subplots((1, 2), domain=domain, land=coast_lats is None)
    # Title timestamp truncated to whole seconds.
    t_label = t.astype('datetime64[s]')
    for name, ax in zip(panel_names, axes):
        # NOTE(review): query="before" presumably picks each particle's last
        # position at or before t — confirm in ParticleResult.get_positions_time.
        lats, lons = results[name].get_positions_time(t, query="before")
        plotting.plot_vectorfield(fields[name], show_time="average", ax=ax, color_speed=False)
        plotting.plot_particle_density(
            lats,
            lons,
            bins=100,
            ax=ax,
            pmax=1,
            title=f"{name} cumulative density at\n{t_label}",
        )
        if coast_lats is not None:
            plotting.plot_coastline(coast_lats, coast_lons, ax=ax, c="k")
    # NOTE(review): if draw_plt does not close `fig`, open figures accumulate
    # over many frames; close it here (plt.close(fig)) in that case.
    plotting.draw_plt(
        savefile=res_path / f"plot_{frame_idx:03d}", fig=fig, fit=True, figsize=(12, 6)
    )

In [None]:
# save plot as final timestamps

if len(common_timestamps) == 0:
    raise ValueError(f"no common timestamps between {to_compare!r}; nothing to plot")
t = common_timestamps[-1]
t_label = t.astype('datetime64[s]')
for name in to_compare:
    particle_ds = results[name].ds
    # Column 0 of lat/lon is taken as each particle's release position. Several
    # particles can share one, so de-duplicate the (lat, lon) pairs, which gives
    # a (2, n_unique) array.
    starts = np.unique(
        np.array([particle_ds["lat"][:, 0], particle_ds["lon"][:, 0]]).T,
        axis=0,
    ).T
    start_lats, start_lons = starts
    lats, lons = results[name].get_positions_time(t, query="before")
    # Same land/coastline handling as the side-by-side frame loop.
    fig, ax = plotting.carree_subplots((1, 1), domain=domain, land=coast_lats is None)
    plotting.plot_particle_density(
        lats,
        lons,
        bins=50,
        ax=ax,
        pmax=0.76,  # NOTE(review): presumably caps the colour scale; confirm in plotting
        title=f"{name} cumulative density at\n{t_label}",
    )
    if coast_lats is not None:
        plotting.plot_coastline(coast_lats, coast_lons, ax=ax, c="k")
    # Release points as red crosses (scatter takes x=lon, y=lat).
    ax.scatter(start_lons, start_lats, marker="x", c="r")
    plotting.draw_plt(
        savefile=res_path / f"density_plot_final_{name}.png",
        fig=fig,
        fit=True,
        figsize=(7, 4.5),
    )