In [None]:
# from cf_units import Unit
# from IPython.core.display import clear_output
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
import matplotlib.patheffects as PathEffects
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np
import xarray as xr
import string

import arke
from arke.cart import lcc_map, lcc_map_grid

from common_defs import winters, nyr, winter_dates, aliases, datasets, period, bbox
from plot_utils import LCC_KW, trans, clev101, abs_plt_kw, iletters, use_style
import mypaths

from octant.core import TrackRun
from octant.decor import get_pbar
from octant.misc import calc_all_dens, DENSITY_TYPES
import octant

octant.__version__

In [None]:
# `scipy.ndimage.filters` is a deprecated alias namespace (SciPy >= 1.8); import from scipy.ndimage directly.
from scipy.ndimage import gaussian_filter

In [None]:
# Switch on octant's progress bars, then fetch the wrapper used around the dataset/run loops below.
octant_runtime = octant.RUNTIME
octant_runtime.enable_progress_bar = True
pbar = get_pbar()

In [None]:
from categorise_and_save import get_lsm

In [None]:
# import warnings

# warnings.filterwarnings("ignore", category=RuntimeWarning, module="dask")
# warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

In [None]:
# ERA5 land-sea mask (presumably subset to `bbox` by get_lsm — see categorise_and_save)
# and 2-D longitude/latitude grids of shape (n_lat, n_lon) built from its coordinates.
lsm_path = mypaths.era5_dir / "lsm.nc"
lsm = get_lsm(lsm_path, bbox=bbox, shift=True)
lon2d, lat2d = np.meshgrid(lsm.longitude, lsm.latitude, indexing="xy")

#### Grids and arrays for density calculation

In [None]:
# Regular lon/lat grid (degrees) passed to calc_all_dens for the density calculation.
dens_step = 1  # grid spacing in degrees (0.3 has also been used)
# Stop half a step past the end so the end point is included regardless of float rounding.
lon_dens1d = np.arange(-20.0, 50.0 + dens_step / 2, dens_step)
lat_dens1d = np.arange(70.0, 85.0 + dens_step / 2, dens_step)
# Label used in output file names; derived from the step so it cannot drift out of sync.
grid_str = f"{dens_step:g}deg"

### Mean sea ice edge position

In [None]:
# sea_ice_ds = xr.open_mfdataset(sorted(mypaths.era5_dir.glob("*.ci.nc")))

In [None]:
# sea_ice_conc = sea_ice_ds.ci[
#     :,
#     (sea_ice_ds.latitude >= 65) & (sea_ice_ds.latitude <= 85),
#     (sea_ice_ds.longitude >= -20) & (sea_ice_ds.longitude <= 50),
# ]

In [None]:
# sic_thresh = 0.15  # 15% threshold

In [None]:
# sea_ice_conc_mean = sea_ice_conc.mean(dim='time')

## Calculate density

In [None]:
# Track-run numbers to process for each dataset, and the counting method passed to calc_all_dens.
runs2process = {"era5": [0], "interim": [0]}
method = "cell"

In [None]:
# Unsmoothed densities, normalised by the number of winters (`nyr`), one NetCDF file per dataset and run.
for dset in pbar(datasets):
    for run_num in pbar(runs2process[dset]):
        track_run = TrackRun.from_archive(mypaths.procdir / f"{dset}_run{run_num:03d}_{period}.h5")

        all_dens = calc_all_dens(track_run, lon_dens1d, lat_dens1d, method=method)
        # The division returns a new array that may not carry attrs; copy them first and restore after.
        attrs = all_dens.attrs.copy()
        all_dens = all_dens / nyr
        all_dens.attrs.update(attrs)

        # Save inside the run loop so every run is written. Saving once per dataset kept only the last
        # run, and re-saved the previous dataset's array if a run list was empty. The run number is in
        # the name, matching the smoothed files, so runs do not overwrite each other.
        all_dens.to_netcdf(
            mypaths.procdir / f"all_dens_{dset}_run{run_num:03d}_{period}_{method}_{grid_str}.nc"
        )

In [None]:
# Keyword-argument bundles shared by the map figures.
# Other grid options tried previously: cbar_location="right", cbar_mode="single", cbar_pad=0.1, cbar_size="3%".
AXGR_KW = {"axes_pad": 0.45}
diff_plt_kw = {"cmap": "coolwarm", "extend": "both", **trans}
cntr_kw = {"colors": "#222222", "linewidths": 0.5, **trans}
cntr_lab_kw = {"fmt": "%3.0f", "colors": "k"}
ci_kw = {"levels": [0.15], "linewidths": 4, **trans}  # 0.15 = 15% sea-ice concentration contour
at_kw = {"loc": 2, "prop": {"size": "small"}}
text_kw = {
    "ha": "center",
    "fontsize": "xx-large",
    # white outline around the text
    "path_effects": [PathEffects.withStroke(linewidth=3, foreground="w")],
}

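The density fields are smoothed with `gaussian_filter`, i.e. convolved (one axis at a time) with the Gaussian kernel

$$g(x) = \frac{1}{\sqrt{2\pi}\,\sigma} \exp\left(-\frac{x^{2}}{2\sigma^{2}}\right),$$

where $\sigma$ is the standard deviation along each axis, in grid cells (`sigma` below).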

### Load tracks

In [None]:
# Runs per dataset, density counting method, and the track subsets to plot.
runs2process = {"era5": [0], "interim": [0]}
method = "cell"
subsets = ["pmc"]

In [None]:
# Load every requested track run into a nested mapping: track_runs[dataset]["runNNN"] -> TrackRun.
track_runs = {}
for dset in pbar(datasets):
    runs_for_dset = {}
    track_runs[dset] = runs_for_dset
    for run_num in pbar(runs2process[dset]):
        archive = mypaths.procdir / f"{dset}_run{run_num:03d}_{period}.h5"
        runs_for_dset[f"run{run_num:03d}"] = TrackRun.from_archive(archive)

In [None]:
# Smoothing applied to the density fields: Gaussian filter with one sigma per array axis (in grid cells).
# NOTE(review): assumes the density array is 4-D with the two spatial axes last — confirm against calc_all_dens.
kwargs = {"sigma": (0, 0, 1, 2.5)}
smooth_func = gaussian_filter

In [None]:
# For every dataset and run:
#   1. compute the densities and normalise them by `nyr`;
#   2. smooth them with `smooth_func(**kwargs)`;
#   3. record the smoothing in the attrs and write one NetCDF file.
for dset in pbar(datasets):
    for run_num in pbar(runs2process[dset]):
        run_label = f"run{run_num:03d}"
        all_dens = calc_all_dens(track_runs[dset][run_label], lon_dens1d, lat_dens1d, method=method)
        # xr.apply_ufunc drops attrs by default, so keep a copy and restore it afterwards.
        attrs = all_dens.attrs.copy()
        all_dens /= nyr
        all_dens = xr.apply_ufunc(smooth_func, all_dens, kwargs=kwargs)
        attrs["smooth"] = f"{smooth_func.__module__}.{smooth_func.__name__} with {kwargs}"
        all_dens.attrs.update(attrs)
        out_path = mypaths.procdir / f"all_dens_{dset}_{run_label}_{period}_{method}_{grid_str}_smoothed.nc"
        all_dens.to_netcdf(out_path)

In [None]:
# from ipywidgets import interact
# @interact(dens_type=DENSITY_TYPES, subset=["pmc"], dset=datasets)
# def fun(dset, dens_type="track", subset="moderate"):

#     fig = plt.figure(figsize=(10, 10))
#     ax = lcc_map(fig, **LCC_KW)

#     with xr.open_dataarray(
#         mypaths.procdir
#         / f"all_dens_{dset}_run{run_num:03d}_{period}_{method}_{grid_str}_smoothed.nc"
#     ) as all_dens:

#         h = all_dens.sel(subset=subset, dens_type=dens_type).plot.contourf(
#             add_colorbar=False, ax=ax, **abs_plt_kw
#         )
#         cb = fig.colorbar(h, pad=0.01, shrink=0.7)

In [None]:
def smooth_bell(r, a=250, b=100):
    """Bell-shaped weight equal to 1 at ``r = 0`` and 0 for ``r >= a``.

    w(r) = (a**2 - r**2) / (a**2 + r**2 * (a**2 / b**2 - 1))  for r < a, else 0.

    Parameters
    ----------
    r : array_like
        Distance(s) from the centre, in the same units as ``a`` and ``b``.
    a : float
        Radius at and beyond which the weight is zero.
    b : float
        Shape parameter controlling how the weight decays inside ``a``.

    Returns
    -------
    numpy.ndarray
        Weights, broadcast to the shape of ``r``.
    """
    # np.where evaluates the formula for every r before masking. Values outside the support
    # (e.g. r = inf, or the zero of the denominator when b > a) would otherwise emit spurious
    # RuntimeWarnings, although those entries are replaced by 0 anyway.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(r < a, (a ** 2 - r ** 2) / (a ** 2 + r ** 2 * (a ** 2 / b ** 2 - 1)), 0)

In [None]:
# Shape of the smooth_bell weight for its default parameters.
rr = np.arange(0, 500, 10)
fig, ax = plt.subplots()
ax.plot(rr, smooth_bell(rr))
ax.set(xlabel="r", ylabel="weight", title="smooth_bell(r) with a=250, b=100");

In [None]:
# One figure per density type: columns are datasets, rows are track subsets.
ncol = len(datasets)
nrow = len(subsets)

for dens_type in pbar(DENSITY_TYPES):
    fig = plt.figure(figsize=(ncol * 5, nrow * 5))
    axgr = lcc_map_grid(fig, (nrow, ncol), **LCC_KW, **AXGR_KW)

    ttl = f"{dens_type.capitalize()} density\n{period.replace('_', '-')} ({nyr} winters)"
    # NOTE(review): x is the first panel's left edge in figure coordinates, but transform=transAxes
    # makes matplotlib interpret x (and suptitle's default y) in that panel's axes coordinates.
    # Confirm the resulting title placement is intended.
    fig.suptitle(
        ttl,
        x=axgr.axes_all[0].get_position().get_points()[0, 0],
        transform=axgr.axes_all[0].transAxes,
        ha="left",
        fontsize="large",
    )

    # Panel labels (a), (b), ... Use a local iterator so the `iletters` imported from plot_utils
    # is not shadowed.
    panel_letters = iter(string.ascii_lowercase)
    for ax in axgr.axes_all:
        ax.set_title(f"({next(panel_letters)})", loc="left", fontsize="medium")

    for axcol, dset in zip(axgr.axes_column, datasets):
        # Plot the first listed run of each dataset explicitly, rather than relying on the
        # `run_num` left over from an earlier cell's loop.
        plot_run_num = runs2process[dset][0]
        dens_file = (
            mypaths.procdir
            / f"all_dens_{dset}_run{plot_run_num:03d}_{period}_{method}_{grid_str}_smoothed.nc"
        )
        with xr.open_dataarray(dens_file) as all_dens:
            for ax, subset in zip(axcol, subsets):
                data = all_dens.sel(subset=subset, dens_type=dens_type)
                ax.add_artist(AnchoredText(f"{aliases[dset]}\n{subset}", **at_kw))
                h = data.plot.contourf(
                    ax=ax, robust=True, add_colorbar=False, add_labels=False, **abs_plt_kw
                )

                # Overlay with sea ice edge
                # sea_ice_conc_mean.plot.contour(ax=ax, add_labels=False, colors="C0", **ci_kw)

                # Colorbar in a small inset in the panel's upper-left corner.
                # Tick labels get a white outline.
                cax = inset_axes(ax, borderpad=0.5, width="4%", height="45%", loc="upper left")
                cb = fig.colorbar(h, orientation="vertical", cax=cax)
                cb.ax.tick_params(labelsize="large")
                for tick_label in cb.ax.get_yticklabels():
                    tick_label.set_path_effects([PathEffects.withStroke(linewidth=2, foreground="w")])
    # fig.savefig(mypaths.plotdir / f'era5_vs_interim_{dens_type}_density_cell_norm_smooth')
    # plt.close()