In [None]:
import string
from pathlib import Path

from cf_units import Unit
from IPython.core.display import clear_output
import iris
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
import matplotlib.patheffects as PathEffects
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np
import pandas as pd
import xarray as xr
from tqdm import tqdm_notebook as tqdm

import arke
from arke.cart import lcc_map, lcc_map_grid

from common_defs import winters, nyr, winter_dates, toponyms
from plot_utils import LCC_KW, trans, clev101, abs_plt_kw, iletters
import mypaths

from octant.core import TrackRun, OctantTrack, HOUR
from octant.misc import calc_all_dens, SUBSETS, DENSITY_TYPES
import octant
octant.__version__  # display the installed octant version, for provenance of the results below

In [None]:
import warnings
warnings.filterwarnings('ignore', category=RuntimeWarning, module='dask')
warnings.filterwarnings('ignore', category=UserWarning, module='iris')

In [None]:
# Apply the project's figure style sheet to every plot created below
plt.style.use('paperfig.mplstyle')

In [None]:
# ERA5 land-sea mask with singleton dimensions squeezed out
lsm_path = mypaths.era5_dir / 'lsm.nc'
lsm = xr.open_dataarray(lsm_path).squeeze()
# Give the mask an explicit dimensionless 'units' attribute
# (presumably so the later iris conversion gets a valid unit — confirm)
lsm.attrs.update(units=1)
# 2D longitude/latitude arrays matching the mask grid
lon2d, lat2d = np.meshgrid(lsm['longitude'], lsm['latitude'])

#### Grids and arrays for density calculation

In [None]:
# Regular 1-degree density grid spanning 20W-50E and 65N-85N; the stop
# values sit slightly past the end points so both ends are included.
dens_step = 1.0
lon_dens1d = np.arange(-20., 50.1, dens_step)
lat_dens1d = np.arange(65., 85.1, dens_step)
# 2D grids with shape (n_lat, n_lon)
lon_dens, lat_dens = np.meshgrid(lon_dens1d, lat_dens1d, indexing='xy')

In [None]:
# Interpolate the land-sea mask onto the density grid and convert to an iris
# cube, then give both horizontal coordinates CF units and cell bounds
# (bounds are needed by iris area-weight calculations).
lsm_1deg = lsm.interp(longitude=lon_dens[0, :], latitude=lat_dens[:, 0]).to_iris()
for _coord_name, _coord_units in (('longitude', 'degrees_east'), ('latitude', 'degrees_north')):
    _coord = lsm_1deg.coord(_coord_name)
    _coord.units = Unit(_coord_units)
    _coord.guess_bounds()

In [None]:
# Absolute grid-cell areas on the 1-degree grid, stored as a cube in km^2
_cell_area_m2 = iris.analysis.cartography.area_weights(lsm_1deg, normalize=False)
weights = lsm_1deg.copy(data=_cell_area_m2)
weights.rename('area_weights')
weights.units = Unit('m^2')
weights.convert_units('km^2')

In [None]:
# xarray version of the `weights` cube (grid-cell areas converted to km^2)
area_weights_1deg = xr.DataArray.from_iris(weights)

In [None]:
# Grid-cell areas normalised by iris (normalize=True), as an xarray DataArray
_area_frac = iris.analysis.cartography.area_weights(lsm_1deg, normalize=True)
area_weights_1deg_norm = xr.DataArray.from_iris(lsm_1deg.copy(data=_area_frac))

In [None]:
# Convert the interpolated land-sea mask cube to an xarray DataArray.
# NOTE(review): this rebinds `lsm_1deg` from an iris Cube to a DataArray, so
# re-running the iris area-weight cells above after this one will presumably
# fail; re-run from the interpolation cell instead.
lsm_1deg = xr.DataArray.from_iris(lsm_1deg)

### Mean sea ice edge position

In [None]:
# All ERA5 sea-ice concentration files, opened lazily as one dataset in sorted file order
ci_files = sorted(mypaths.era5_dir.glob('*.ci.nc'))
sea_ice_ds = xr.open_mfdataset(ci_files)

In [None]:
# Sea-ice concentration over the density-grid domain (20W-50E, 65N-85N).
# Select by coordinate *name*, not position: the previous positional
# `ci[:, lat_mask, lon_mask]` silently assumed a (time, lat, lon) dimension order.
# NOTE(review): assumes longitudes are stored in [-180, 180); with 0-360
# longitudes the 20W-0 sector would be missing — confirm for the ERA5 files.
sea_ice_conc = sea_ice_ds.ci.sel(
    latitude=(sea_ice_ds.latitude >= 65) & (sea_ice_ds.latitude <= 85),
    longitude=(sea_ice_ds.longitude >= -20) & (sea_ice_ds.longitude <= 50),
)

In [None]:
sic_thresh = 0.15  # 15% concentration threshold; used as the contour level for the mean sea-ice edge

In [None]:
# Time-mean sea-ice concentration (a small 2D field). Computed eagerly once:
# it is overlaid on every panel of the comparison figures below, and leaving
# it lazy would re-read all ERA5 sea-ice files for each panel.
sea_ice_conc_mean = sea_ice_conc.mean(dim='time').compute()

### Load tracks

In [None]:
# Track run selection; the run number in file names is run_id_start + run_id
run_id_start = 0 # 106
run_id = 0
dataset = 'era5'  # prefix of the processed track file names

In [None]:
# Load the processed 2008-2017 tracks for the selected run into a TrackRun
track_file = mypaths.procdir / f'{dataset}_run{run_id_start+run_id:03d}_2008_2017.parquet'
TR = TrackRun()
TR.data = OctantTrack.from_mux_df(pd.read_parquet(track_file, engine='pyarrow'))

In [None]:
# Size of each track subset (presumably the number of tracks — confirm TrackRun.size)
TR.size('basic'), TR.size('moderate'), TR.size('strong')

In [None]:
# One `max_vort` value per track in the 'moderate' subset, indexed by track_idx
vo_per_track = TR['moderate'].groupby('track_idx').apply(lambda x: x.max_vort)

In [None]:
# 95th percentile of the per-track `max_vort` values
np.percentile(vo_per_track, 95)

In [None]:
# Tracks whose `max_vort` exceeds the 95th percentile of all moderate tracks
vo_p95 = np.percentile(vo_per_track, 95)
strong = vo_per_track[vo_per_track > vo_p95]

In [None]:
# Number of moderate tracks above the 95th-percentile threshold
strong.size

In [None]:
from octant.core import TrackSettings

# Tracking settings used for this run, read from the first *.conf file of the
# 2008_2009 output directory. The glob is sorted so the choice is deterministic,
# and a missing file gives a clear error instead of a bare IndexError.
conf_dir = mypaths.trackresdir / dataset / f'run{run_id + run_id_start:03d}' / '2008_2009'
conf_files = sorted(conf_dir.glob('*.conf'))
if not conf_files:
    raise FileNotFoundError(f'No *.conf file found in {conf_dir}')
ts = TrackSettings(conf_files[0])
ts.zeta_max0

## Calculate density

In [None]:
# Density settings. The 1x1 degree grid (lon_dens1d, lat_dens1d, lon_dens,
# lat_dens) is defined once in the "Grids and arrays for density calculation"
# section; it is not redefined here, to keep a single source of truth.
# Radius in km passed to calc_all_dens; presumably ~1 degree of latitude — TODO confirm.
r = 111.3
grid_str = r'$1^\degree\times 1^\degree$'  # grid description used in figure titles

In [None]:
# All density types/subsets on the 1-degree grid, divided by the number of
# years (nyr) to give per-year values. The original attributes are
# re-attached because arithmetic on a DataArray may drop them.
_dens_total = calc_all_dens(TR, lon_dens, lat_dens, r=r)
all_dens = (_dens_total / nyr).assign_attrs(_dens_total.attrs)
del _dens_total
clear_output()  # clear any output printed while computing the densities

In [None]:
# all_dens.to_netcdf(mypaths.procdir / f'{dataset}_run{run_id_start+run_id:03d}_2008_2017_all_densities_r{round(r):3d}.nc')

In [None]:
# Layout of the map grid: a single colourbar on the right
AXGR_KW = dict(axes_pad=0.45,
               cbar_location='right',
               cbar_mode='single',
               cbar_pad=0.1,
               cbar_size='3%')
# Keyword sets for filled difference plots, line contours and their labels
diff_plt_kw = dict(cmap='coolwarm', extend='both', **trans)
cntr_kw = dict(colors='#222222', linewidths=0.5, **trans)
cntr_lab_kw = dict(fmt='%3.0f', colors='k')
# Sea-ice edge contour at the shared concentration threshold (was a hard-coded 0.15)
ci_kw = dict(levels=[sic_thresh], linewidths=4, **trans)
at_kw = dict(loc=2, prop=dict(size='small'))
# Place-name labels drawn with a white halo for readability over filled contours
text_kw = dict(ha='center',
               fontsize='xx-large',
               path_effects=[PathEffects.withStroke(linewidth=3,
                                                    foreground='w')])

In [None]:
subset = 'moderate'
dens_type = 'track'
show_sea_ice = False   # overlay the mean sea-ice edge contour
show_toponyms = False  # label geographic names on the map

fig = plt.figure(figsize=(10, 10))
axgr = lcc_map_grid(fig, (1, 1), **LCC_KW, **AXGR_KW)
ax = axgr.axes_all[0]
cax = axgr.cbar_axes[0]

h = all_dens.sel(subset=subset, dens_type=dens_type).plot.contourf(ax=ax,
                                                                   levels=np.arange(3, 30, 3),
                                                                   add_colorbar=False,
                                                                   add_labels=False,
                                                                   **abs_plt_kw)
cb = fig.colorbar(h, cax=cax)
cb.ax.tick_params(labelsize='x-large')

if show_sea_ice:
    sea_ice_conc_mean.plot.contour(ax=ax, add_labels=False, colors='C9', **ci_kw)

if show_toponyms:
    # A separate loop variable keeps `ax` pointing at the density panel
    for topo_ax in axgr.axes_all:
        for topo in toponyms:
            txt = topo_ax.text(topo['lon'], topo['lat'], topo['name'], **text_kw, **trans)
            txt.set_zorder(100)

# Title and file name follow `dens_type`/`dataset` instead of hard-coding "track"/"era5"
ttl = f'Polar low {dens_type} density\n{dataset.upper()}\nr = {r} km\n2008-2017 (9 winters)'
ax.add_artist(AnchoredText(ttl, loc=4, prop=dict(size='large')))

fig.savefig(mypaths.plotdir / f'pmctrack_{dataset}_{subset}_{dens_type}_density_r{round(r):3d}')

In [None]:
# `interact` builds the interactive density browser below
from ipywidgets import interact

In [None]:
# Scale densities by each cell's area relative to the largest cell on the grid
rel_cell_area = area_weights_1deg / area_weights_1deg.max()
all_dens_norm = all_dens * rel_cell_area

In [None]:
@interact(dens_type=DENSITY_TYPES, subset=SUBSETS)
def fun(dens_type='track', subset='moderate'):
    """Plot the area-scaled density for one density type and track subset.

    Driven by ipywidgets dropdowns over DENSITY_TYPES and SUBSETS.

    Parameters
    ----------
    dens_type : str
        Value selected along the ``dens_type`` dimension of ``all_dens_norm``.
    subset : str
        Value selected along the ``subset`` dimension of ``all_dens_norm``.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = lcc_map(fig, **LCC_KW)

    # Draw on the map axes explicitly rather than pyplot's implicit "current axes"
    h = all_dens_norm.sel(subset=subset, dens_type=dens_type).plot.contourf(
        ax=ax, add_colorbar=False, **abs_plt_kw)
    fig.colorbar(h, ax=ax, pad=0.01, shrink=0.7)

In [None]:
# fig.savefig(mypaths.plotdir / 'climatology' / f'pmctrack_density_point_{dataset}_{density_kw["subset"]}_r{density_kw["r"]:3.0f}.{fmt}', **svfigkw)

In [None]:
# Plot settings for the multi-panel dataset comparison figures.
# Each panel gets its own inset colourbar, so the grid only needs a small gap.
# The map-transform mapping is `trans` (as in the single-panel settings);
# the previously used `mapkey` was never defined (NameError).
AXGR_KW = dict(axes_pad=0.05)
abs_plt_kw = dict(cmap='Oranges', extend='max', **trans)
diff_plt_kw = dict(cmap='coolwarm', extend='both', **trans)
cntr_kw = dict(colors='#222222', linewidths=0.05, **trans)
cntr_lab_kw = dict(fmt='%3.0f', colors='k')
# Sea-ice edge contour at the shared concentration threshold
ci_kw = dict(levels=[sic_thresh], linewidths=2, **trans)
at_kw = dict(loc=1, prop=dict(size='large'))

In [None]:
# (file-name prefix, panel label) for each dataset; one row of the comparison figures per entry
dset_names = (
    ('era5_run000', 'ERA5, Run A'),
    ('interim_run106', 'ERA-Interim, Run A'),
    ('interim_run100', 'ERA-Interim, Run B')
)

In [None]:
# Output settings for the comparison figures.
# NOTE(review): `fmt` and `svfigkw` were used but never defined in this
# notebook; these values are assumptions — confirm the intended format/options.
fmt = 'png'
svfigkw = dict(bbox_inches='tight')

nrow = len(dset_names)  # one row per dataset
ncol = len(SUBSETS)     # one column per track subset

for dens_type in tqdm(DENSITY_TYPES, desc='figures', leave=False):
    fig = plt.figure(figsize=(ncol * 5, nrow * 5))
    axgr = lcc_map_grid(fig, (nrow, ncol), **LCC_KW, **AXGR_KW)

    ttl = f'{dens_type.capitalize()} density\nr = {r} km, {grid_str}\n2008-2017 (9 winters)'
    axgr.axes_all[0].set_title(ttl, loc='left', fontsize='x-large')

    ax_labels = iter(string.ascii_lowercase)  # panel letters (a), (b), ...
    for axrow, (dset_name, dset_label) in tqdm(zip(axgr.axes_row, dset_names),
                                               total=nrow, desc='datasets', leave=False):
        # Load into memory and close the file. A dedicated name avoids
        # overwriting the notebook-level `all_dens` computed earlier.
        dset_dens = xr.load_dataarray(
            mypaths.procdir / f'{dset_name}_2008_2017_all_densities_r{round(r):3d}.nc')

        for ax, subset in tqdm(zip(axrow, SUBSETS), total=ncol, desc='subsets', leave=False):
            data = dset_dens.sel(subset=subset, dens_type=dens_type)
            ax.add_artist(AnchoredText(f'({next(ax_labels)}) {dset_label}\n{subset}', **at_kw))

            h = data.plot.contourf(ax=ax, robust=False, add_colorbar=False, add_labels=False, **abs_plt_kw)
            # Overlay with sea ice edge
            sea_ice_conc_mean.plot.contour(ax=ax, add_labels=False, colors='C0', **ci_kw)

            # Per-panel colourbar inset in the upper-left corner of the map
            cax = inset_axes(ax, borderpad=0.5, width="4%", height="45%", loc=2)
            cb = fig.colorbar(h, orientation='vertical', cax=cax)
            cb.ax.tick_params(labelsize='large')
            for tick_label in cb.ax.get_yticklabels():
                tick_label.set_path_effects([PathEffects.withStroke(linewidth=2, foreground='w')])

    fig.savefig(mypaths.plotdir / f'pmctrack_era5_vs_interim_{dens_type}_density_r{round(r):3d}.{fmt}', **svfigkw)
    plt.close(fig)  # close this figure explicitly to free memory inside the loop