In [None]:
import xarray as xr
import numpy as np
import pyfesom2 as pf
import matplotlib.tri as mtri
import matplotlib as mpl
import matplotlib.colors as clr
from matplotlib.colors import LinearSegmentedColormap
import cmocean

In [None]:
# Input/output locations (placeholders -- point these at your own directories).
datapath, meshpath, savepath = (
    '/PATH/TO/DATA/',
    '/PATH/TO/MESH/DATA/',
    '/PATH/TO/OUTPUT/',
)

In [None]:
mesh = pf.load_mesh(meshpath)
# Node areas from the mesh diagnostics file, index 0 of nod_area's first
# dimension (presumably the surface layer -- confirm against the file).
mesh_diag = xr.open_dataset(meshpath + 'fesom.mesh.diag.nc')
so3a = mesh_diag.nod_area[0, :]

In [None]:
# Rossby radius on a regular grid (0.25-degree judging by the file name -- confirm).
rossby_file = datapath + 'rossrad_miss_025.nc'
ross = xr.open_dataset(rossby_file)

In [None]:
# Approximate model resolution: per-element height h = 2 * area / mean edge length

In [None]:
def haversine(lon1, lat1, lon2, lat2):
    """Great-circle distance between two sets of points given in degrees.

    Uses the haversine formula on a sphere of radius 6371 (km), so the result
    is in km. Works element-wise on scalars, numpy arrays or xarray objects.
    """
    earth_radius_km = 6371
    phi1 = np.radians(lat1)
    phi2 = np.radians(lat2)
    half_dphi = (phi2 - phi1) / 2
    half_dlambda = (np.radians(lon2) - np.radians(lon1)) / 2

    hav = np.sin(half_dphi) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(half_dlambda) ** 2
    central_angle = 2 * np.arcsin(np.sqrt(hav))
    return central_angle * earth_radius_km

In [None]:
# Vertex indices of every element; the stored connectivity appears to be
# 1-based, hence the -1.
# NOTE(review): `so3a` is a DataArray (nod_area[0, :]); `.elements` and
# `.elem_area` only resolve if the file exposes them as coordinates/attrs -- confirm.
vert_a = so3a.elements[:, :][0] - 1
vert_b = so3a.elements[:, :][1] - 1
vert_c = so3a.elements[:, :][2] - 1
lon_a, lat_a = so3a.lon[vert_a], so3a.lat[vert_a]
lon_b, lat_b = so3a.lon[vert_b], so3a.lat[vert_b]
lon_c, lat_c = so3a.lon[vert_c], so3a.lat[vert_c]

# Edge lengths in km (haversine uses a 6371 km radius) and their mean.
le1 = haversine(lon_a, lat_a, lon_b, lat_b)
le2 = haversine(lon_a, lat_a, lon_c, lat_c)
le3 = haversine(lon_b, lat_b, lon_c, lat_c)
lm = np.mean([le1, le2, le3], axis=0)

# Triangle height over the mean edge: h = 2 * area / mean edge length.
# NOTE(review): lm is in km; confirm the units of elem_area so that h matches
# the "Grid spacing (km)" label applied to regout / 1000 in the figure cell.
h = (so3a.elem_area[:] * 2) / lm

In [None]:
model_lons, model_lats = mesh.x2, mesh.y2
elements = mesh.elem.astype('int32')

# Longitude span of each element; elements spanning 100 degrees or more
# (presumably those crossing the dateline) are excluded from the triangulation.
elem_lons = model_lons[elements]
d = elem_lons.max(axis=1) - elem_lons.min(axis=1)
no_cyclic_elem = np.flatnonzero(d < 100)

In [None]:
# Regular 0.25-degree target grid of cell centres (4 points per degree,
# lon -179.875..179.875, lat -89.875..89.875). The plotting code relies on
# this layout: latitude 0 sits at row 360 and longitude 0 at column 720.
dx = 0.25
dy = 0.25
left, right = -179.875, 180
bottom, top = -89.875, 90
# Extent of the grid in degrees (positive: right - left, top - bottom).
nx2 = right - left
ny2 = top - bottom
lon_eq = np.arange(left, right, dx)
lat_eq = np.arange(bottom, top, dy)
nx = lon_eq.shape[0]
ny = lat_eq.shape[0]

# Transposed so both arrays are indexed [lon, lat] with shape (nx, ny). The grid
# is deliberately non-square (1440 x 720), which makes the axes easy to tell apart.
xx_eq, yy_eq = np.meshgrid(lon_eq, lat_eq)
xx_eq = xx_eq.T
yy_eq = yy_eq.T

In [None]:
# Triangulation of the model nodes using only the elements kept in
# no_cyclic_elem, plus a reusable trifinder for the interpolation below.
triang = mtri.Triangulation(model_lons, model_lats, triangles=elements[no_cyclic_elem])
tri = triang.get_trifinder()

In [None]:
# Linearly interpolate h onto the (lon, lat) points of the Rossby-radius grid.
# NOTE(review): h is built per element (from elem_area), whereas
# LinearTriInterpolator expects one value per triangulation node -- confirm sizes.
h_interpolator = mtri.LinearTriInterpolator(triang, h, trifinder=tri)
regout = h_interpolator(ross.lon.T, ross.lat.T)

In [None]:
# Grid spacing expressed as a multiple of the local Rossby radius
# (var100 taken at index 0 of its first dimension).
rossby_radius = ross.var100[0, :].values
gs = regout.T / rossby_radius

In [None]:
#plotting

In [None]:
import matplotlib as mpl
from cmocean import cm as cmo
from matplotlib import ticker
import matplotlib.path as mpath
import cartopy.crs as ccrs
from matplotlib import pyplot as plt

from pyfesom2.ut import cut_region, get_cmap, get_no_cyclic, mask_ne, vec_rotate_r2g

In [None]:
# Panel labels annotated onto the subplots.
labels = list('abcdef')

In [None]:
#create the colormap used for the plots

In [None]:
# Colormap anchor colours, from near-white to dark green; wbg4 is built from these.
rrcml = [
    '#fff7fb', '#ece2f0', '#d0d1e6',
    '#a6bddb', '#67a9cf', '#3690c0',
    '#02818a', '#016c59', '#014636',
]

In [None]:
# Six-colour discrete colormap from rrcml[2:-1]; rrcml[1] marks values below
# the colour range and rrcml[-1] values above it.
wbg4 = LinearSegmentedColormap.from_list('mycmap', rrcml[2:-1], N=6)
under_rgba = mpl.colors.to_rgba(rrcml[1])
over_rgba = mpl.colors.to_rgba(rrcml[-1])
wbg4.set_over(over_rgba)
wbg4.set_under(under_rgba)

In [None]:
def create_proj_figure(mapproj, rowscol, figsize, dpi):
    """Create a figure and axes that use a cartopy projection.

    Parameters
    ----------
    mapproj: str
        Name of the projection:
            merc: Mercator
            pc: PlateCarree
            np: NorthPolarStereo
            sp: SouthPolarStereo
            rob: Robinson
    rowscol: (int, int)
        Number of rows and columns of subplots.
    figsize: (float, float)
        Width, height in inches.
    dpi: float
        Figure resolution in dots per inch.

    Returns
    -------
    fig, ax
        As returned by ``plt.subplots`` with ``constrained_layout=True``
        (a single axes for a 1x1 layout, otherwise an array of axes).

    Raises
    ------
    ValueError
        If ``mapproj`` is not one of the names listed above.
    """
    # Factories rather than instances, so only the requested projection is built.
    projection_factories = {
        "merc": lambda: ccrs.Mercator(),
        "pc": lambda: ccrs.PlateCarree(),
        "np": lambda: ccrs.NorthPolarStereo(),
        "sp": lambda: ccrs.SouthPolarStereo(),
        "rob": lambda: ccrs.Robinson(),
    }
    try:
        make_projection = projection_factories[mapproj]
    except (KeyError, TypeError):
        # TypeError covers unhashable values; both are simply unsupported names.
        raise ValueError(f"Projection {mapproj} is not supported.") from None

    fig, ax = plt.subplots(
        rowscol[0],
        rowscol[1],
        subplot_kw=dict(projection=make_projection()),
        constrained_layout=True,
        figsize=figsize,
        dpi=dpi,
    )
    return fig, ax

# Changed to center the levels at 0: levels_to_data clamps the requested range to [-max|data|, +max|data|].
def get_plot_levels(levels, data, lev_to_data=False):
    """Build the array of contour/colour levels for a plot.

    Parameters
    ----------
    levels: list, numpy array or None
        With exactly three elements it is read as (min, max, number_of_levels).
        With more than three elements the values are used directly as levels.
        With None, 40 levels spanning the data range (NaNs ignored) are used.
    data: numpy array or xarray
        Data that will be plotted with these levels.
    lev_to_data: bool
        If True, a (min, max, n) request is first clamped with
        levels_to_data, which is needed for safe plotting on a triangular
        grid with cartopy.

    Returns
    -------
    data_levels: numpy array
        The resulting levels.

    Raises
    ------
    ValueError
        If ``levels`` has fewer than three elements.
    """
    if levels is None:
        return np.linspace(np.nanmin(data), np.nanmax(data), 40)

    if len(levels) < 3:
        raise ValueError(
            "Levels can be the list or numpy array with three or more elements."
        )

    if len(levels) > 3:
        # An explicit list of levels is taken as-is.
        return np.array(levels)

    lowest, highest, count = levels
    if lev_to_data:
        lowest, highest = levels_to_data(lowest, highest, data)
    return np.linspace(lowest, highest, int(count))

def levels_to_data(mmin, mmax, data):
    """Clamp a (min, max) level range to the data's symmetric range around 0.

    The bound is the largest absolute data value (NaNs ignored): a minimum
    below -bound is raised to -bound and a maximum above +bound is lowered to
    +bound, with a message printed for each change. Cartopy can't plot on a
    triangular mesh when the colour range is larger than the data range.

    Returns
    -------
    (mmin, mmax) after clamping.
    """
    bound = np.nanmax(abs(data))
    if mmin < -bound:
        mmin = -bound
        print("minimum level changed to make cartopy happy")
    if mmax > bound:
        mmax = bound
        print("maximum level changed to make cartopy happy")
    return mmin, mmax

def _grid_slices(box):
    """Return (row, column) slices of a global 0.25-degree field covering `box`.

    Uses the offsets this plotting code has always assumed: 4 grid points per
    degree, latitude 0 at row 360 and longitude 0 at column 720.
    """
    lon_min, lon_max, lat_min, lat_max = (int(round(4 * edge)) for edge in box)
    return slice(360 + lat_min, 360 + lat_max), slice(720 + lon_min, 720 + lon_max)


def _per_field(option, n_fields, shared):
    """Expand a plotting option to one entry per data field.

    If `shared` is true the same option is reused for every field; otherwise
    `option` must already be indexable by field.
    """
    return [option] * n_fields if shared else list(option)


def _color_limits(field_levels, field):
    """Return (vmin, vmax) for one field from its entry in `levels`."""
    if field_levels is None:
        # No levels given: span the finite data values (masked/NaN ignored).
        values = np.ma.masked_invalid(np.ma.asarray(field, dtype=float))
        return values.min(), values.max()
    if len(field_levels) < 3:
        raise ValueError(
            "Levels can be the list or numpy array with three or more elements."
        )
    if len(field_levels) == 3:
        # (min, max, number_of_levels) triple.
        return field_levels[0], field_levels[1]
    # Explicit list of levels: use its full range.
    return min(field_levels), max(field_levels)


def _add_polar_frame(axis, ygridlocs):
    """Clip a polar-stereographic axis to a circle and draw dashed gridlines."""
    theta = np.linspace(0, 2 * np.pi, 100)
    center, radius = [0.5, 0.5], 0.5
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
    circle = mpath.Path(verts * radius + center)
    axis.set_boundary(circle, transform=axis.transAxes)

    gridline_style = dict(
        crs=ccrs.PlateCarree(),
        draw_labels=False,
        linewidth=0.5,
        color='gray',
        alpha=0.5,
        linestyle='--',
        zorder=10,
    )
    # Meridians every 45 degrees, then the requested parallels.
    axis.gridlines(xlocs=range(-180, 171, 45), ylocs=[], **gridline_style)
    axis.gridlines(xlocs=[], ylocs=ygridlocs, **gridline_style)


def btplot2(
    data,
    cmap=None,
    box=[-180, 180, -90, 90],
    mapproj="pc",
    levels=None,
    ptype="cf",
    units=r"$^\circ$C",
    figsize=(10, 10),
    rowscol=(1, 1),
    titles=None,
    lw=0.01,
    fontsize=12,
    box_expand=1,
    dpi=100,
    tickdist=10,
    ext='both',
    cbar_or='horizontal',
    contour_data=None,
    ygridlocs=[-75, -60, -45],
    ticklab=None,
    contour_levels=(15,),
):
    """Plot fields on a global 0.25-degree grid on cartopy maps with pcolormesh.

    Parameters
    ----------
    data: 2-D array or list of 2-D arrays
        Fields indexed [lat, lon] on the global 0.25-degree grid whose
        coordinates are read from the notebook-level ``ross`` dataset.
    cmap: str or Colormap
        Passed to pyfesom2's ``get_cmap``.
    box: list
        Map boundaries [lon_min, lon_max, lat_min, lat_max]; also selects the
        part of each field that is drawn.
    mapproj: str
        Map projection: Mercator (merc), Plate Carree (pc), North Polar
        Stereo (np), South Polar Stereo (sp), Robinson (rob).
    levels: None, (min, max, n), list of levels, or one such entry per field
        Sets the colour range. A triple uses its first two values; a longer
        list uses its min and max; None uses the data range.
    ptype: str
        Only 'cf' (drawn with pcolormesh) is supported.
    units: str or list of str
        Colorbar label(s); a single string is used for every field.
    figsize: tuple
        Figure size in inches.
    rowscol: tuple
        Number of rows and columns.
    titles: str or list
        Title of the plot (if string) or of each subplot (if list of strings).
    lw: float
        Accepted for compatibility; currently unused.
    fontsize: float
        Font size of the colorbar ticks and labels.
    box_expand: float
        Accepted for compatibility; currently unused.
    dpi: float
        Figure resolution.
    tickdist: number, None, or one entry per field
        Spacing of the colorbar ticks (None lets matplotlib choose).
    ext: str
        Colorbar ``extend`` option.
    cbar_or: str
        Colorbar orientation ('horizontal' or 'vertical').
    contour_data: None or per-field pairs of 2-D arrays
        If given, entry [i][0] is contoured dashed white and [i][1] solid
        black on subplot i.
    ygridlocs: list
        Latitudes of the parallels drawn on polar projections.
    ticklab: None or one entry (None or list of str) per field
        Custom colorbar tick labels.
    contour_levels: sequence
        Contour levels used for ``contour_data`` (default 15).

    Returns
    -------
    fig, ax
        The figure and the flat list/array of its axes.

    Raises
    ------
    ValueError
        On mismatched titles, too few subplots, an unsupported ``ptype``,
        or a ``levels`` entry with fewer than three elements.
    """
    if not isinstance(data, list):
        data = [data]
    n_fields = len(data)

    if titles:
        if not isinstance(titles, list):
            titles = [titles]
        if len(titles) != n_fields:
            raise ValueError(
                "The number of titles do not match the number of data fields, please adjust titles (or put to None)"
            )

    if (rowscol[0] * rowscol[1]) < n_fields:
        raise ValueError(
            "Number of rows*columns is smaller than number of data fields, please adjust rowscol."
        )

    # Validate the plot type before any figure is created.
    if ptype == "tri":
        raise ValueError("Not supported.")
    if ptype != "cf":
        raise ValueError(
            "Only `cf` (contourf) and `tri` (tripcolor) options are supported."
        )

    # Scalar/None options are shared by all fields; sequences are per field.
    levels_per_field = _per_field(
        levels,
        n_fields,
        shared=levels is None or all(np.isscalar(v) for v in levels),
    )
    tickdist_per_field = _per_field(
        tickdist, n_fields, shared=tickdist is None or np.isscalar(tickdist)
    )
    ticklab_per_field = _per_field(ticklab, n_fields, shared=ticklab is None)
    units_per_field = _per_field(
        units, n_fields, shared=not units or isinstance(units, str)
    )

    sfmt = ticker.ScalarFormatter(useMathText=True)
    sfmt.set_powerlimits((-3, 4))
    colormap = get_cmap(cmap=cmap)

    fig, ax = create_proj_figure(mapproj, rowscol, figsize, dpi=dpi)
    ax = ax.flatten() if isinstance(ax, np.ndarray) else [ax]

    # Loop-invariant: the box selection and its 1-D grid coordinates.
    rows, cols = _grid_slices(box)
    lon_1d = ross.lon.T[:, 0][cols]
    lat_1d = ross.lat.T[0][rows]

    for ind, field in enumerate(data):
        axis = ax[ind]
        vmin, vmax = _color_limits(levels_per_field[ind], field)

        axis.set_extent(box, crs=ccrs.PlateCarree())
        image = axis.pcolormesh(
            lon_1d,
            lat_1d,
            field[rows, cols],
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            cmap=colormap,
        )

        if contour_data is not None:
            for layer, color, style in ((0, 'white', 'dashed'), (1, 'black', 'solid')):
                axis.contour(
                    lon_1d,
                    lat_1d,
                    np.asarray(contour_data[ind][layer])[rows, cols],
                    levels=list(contour_levels),
                    transform=ccrs.PlateCarree(),
                    colors=color,
                    linestyles=style,
                    linewidths=1,
                )

        if mapproj in ('np', 'sp'):
            _add_polar_frame(axis, ygridlocs)

        axis.coastlines(lw=0.5, resolution="110m", facecolor='grey')

        if titles:
            axis.set_title(titles[ind], size=20)

        field_tickdist = tickdist_per_field[ind]
        # The +1 keeps vmax itself among the ticks when it is a multiple of the step.
        ticks = None if field_tickdist is None else np.arange(vmin, vmax + 1, field_tickdist)
        cb = fig.colorbar(
            image, orientation=cbar_or, ax=axis, pad=0.01, shrink=0.8,
            ticks=ticks, extend=ext, format=sfmt,
        )
        field_ticklab = ticklab_per_field[ind]
        if field_ticklab is not None:
            # Tick labels live on the colorbar's long axis.
            if cbar_or == 'horizontal':
                cb.ax.set_xticklabels(field_ticklab)
            else:
                cb.ax.set_yticklabels(field_ticklab)
        cb.ax.tick_params(labelsize=fontsize)

        if units_per_field[ind]:
            cb.set_label(units_per_field[ind], size=fontsize)

    # Remove subplots that received no data.
    for unused_axis in ax[n_fields:]:
        fig.delaxes(unused_axis)

    return fig, ax

In [None]:
fig, ax = btplot2(
    data=[regout.T / 1000, np.log2(gs)],
    figsize=(11, 6.5),
    rowscol=(1, 2),
    levels=[[3, 27, 6], [-2, 1, 1]],
    mapproj='sp',
    # Second panel is log2(spacing / R): one tick per power of two puts the
    # four labels below on -2, -1, 0 and 1 (a step of 2 gave only -2 and 0).
    tickdist=[4, 1],
    units=['Grid spacing (km)', 'Grid spacing as a multiple of\nthe local Rossby Radius (R)'],
    cmap=wbg4.reversed(),
    box=[-180, 180, -90, -25],
    dpi=500,
    ptype='cf',
    ygridlocs=[-75, -60, -45, -30],
    ticklab=[None, ['1/4R', '1/2R', '1R', '2R']],
    fontsize=18,
)

# Panel letters in the top-left corner of each subplot.
for panel_label, axis in zip(labels, ax):
    axis.annotate(panel_label, xy=(0.05, 0.9), xycoords='axes fraction',
                  horizontalalignment='left', verticalalignment='bottom',
                  fontsize=22, weight='bold')
fig.savefig(savepath + 'Supplementary_fig1.png', bbox_inches='tight')