# Plots for the analysis of the spatial patterns of eddy kinetic energy (EKE)
This notebook produces maps of the mean EKE as modeled and as observed during the observational period (1993-2020), and of the modeled change in mean EKE between a historical period (1860-1949) and a projected period (2061-2090).

The plotting functions are largely adapted from pyfesom2 (https://github.com/FESOM/pyfesom2).

In [None]:
import xarray as xr
import pyfesom2 as pf
import numpy as np
from matplotlib import pyplot as plt

#necessary for plotting 
import math
import os
import sys

import joblib
import matplotlib as mpl
import shapely.vectorized
import xarray as xr
from cmocean import cm as cmo
from matplotlib import cm, ticker
from matplotlib.colors import LinearSegmentedColormap
from netCDF4 import Dataset, MFDataset, num2date
import cartopy.crs as ccrs
import cartopy

from pyfesom2.load_mesh_data import ind_for_depth
from pyfesom2.regridding import fesom2regular, tonodes
from pyfesom2.transect import transect_get_nodes
from pyfesom2.ut import cut_region, get_cmap, get_no_cyclic, mask_ne, vec_rotate_r2g

In [None]:
# Directory holding the FESOM2 mesh files (griddes.nc is read from here too).
meshpath = '/path/to/mesh/data/'

In [None]:
# Load the FESOM2 mesh and its grid description; griddes.coast is used
# later to mask values before plotting.
mesh = pf.load_mesh(meshpath)
# os.path.join works whether or not meshpath ends with a path separator,
# unlike plain string concatenation.
griddes = xr.open_dataset(os.path.join(meshpath, 'griddes.nc'))

In [None]:
# Difference of means computed from the running-mean anomaly data.
# The slice lengths equal 73 steps per year (6570 = 90 * 73, 2190 = 30 * 73),
# consistent with the 5-day means named in the file name.
# NOTE(review): assumes the time axis covers exactly 1860-2090 without leap
# days ("nol" in the file name) - confirm against the file.
STEPS_PER_YEAR = 73
HIST_STEPS = (1949 - 1860 + 1) * STEPS_PER_YEAR  # 1860-1949 -> 6570 steps
PROJ_STEPS = (2090 - 2061 + 1) * STEPS_PER_YEAR  # 2061-2090 -> 2190 steps

with xr.open_dataset(
    'c6_ssh-eke_mon_lin_eq_5day_node_ensmean_1860-2090_nol_21yr_anom_rename.nc'
) as ekerm:
    # The results are used after the file is closed; this relies on .mean()
    # loading the data eagerly (no `chunks` requested) - revisit if dask
    # chunking is ever added here.
    rmh = ekerm.EKE[:HIST_STEPS, :].mean(dim='time')   # historic mean
    rme = ekerm.EKE[-PROJ_STEPS:, :].mean(dim='time')  # projected mean
    rmdif = rme - rmh                                  # projected minus historic

In [None]:
# Observed and modeled mean EKE during the observational period (1993-2020),
# both from reference-period anomalies. The files are closed once the means
# are computed; this relies on .mean() loading the data eagerly (no `chunks`
# requested).
with xr.open_dataset('obs_eke_1993-2020_nol_5d_rm_mdata.nc') as fobs:
    obsm = fobs.EKE.mean(dim='time')

with xr.open_dataset(
    'c6_ssh-eke_mon_lin_eq_5day_node_ensmean_1993-2020_nol_obsanom_mdata.nc'
) as fmod:
    modm = fmod.EKE.mean(dim='time')

Define modified versions of the pyfesom2 plotting functions used below.

In [None]:
#added dpi
def create_proj_figure(mapproj, rowscol, figsize, dpi):
    """Create a figure whose subplots all share one cartopy projection.

    Parameters
    ----------
    mapproj: str
        Short projection name, one of:
            merc: Mercator
            pc: PlateCarree
            np: NorthPolarStereo
            sp: SouthPolarStereo
            rob: Robinson
    rowscol: (int, int)
        Number of subplot rows and columns.
    figsize: (float, float)
        Figure width and height in inches.
    dpi: float
        Figure resolution in dots per inch.

    Returns
    -------
    fig, ax
        The figure and the axis (or array of axes) from ``plt.subplots``,
        created with ``constrained_layout=True``.

    Raises
    ------
    ValueError
        If ``mapproj`` is not one of the names listed above.
    """
    # Short name -> cartopy.crs class name.
    supported = (
        ("merc", "Mercator"),
        ("pc", "PlateCarree"),
        ("np", "NorthPolarStereo"),
        ("sp", "SouthPolarStereo"),
        ("rob", "Robinson"),
    )
    crs_name = next((name for key, name in supported if mapproj == key), None)
    if crs_name is None:
        raise ValueError(f"Projection {mapproj} is not supported.")

    # Only the requested projection is instantiated.
    projection = getattr(ccrs, crs_name)()
    fig, ax = plt.subplots(
        rowscol[0],
        rowscol[1],
        subplot_kw={"projection": projection},
        constrained_layout=True,
        figsize=figsize,
        dpi=dpi,
    )
    return fig, ax

In [None]:
#changed to center the levels at 0
def get_plot_levels(levels, data, lev_to_data=False):
    """Returns levels for the plot.
    Parameters
    ----------
    levels: list, numpy array
        Can be list or numpy array with three or more elements.
        If only three elements provided, they will b einterpereted as min, max, number of levels.
        If more elements provided, they will be used directly.
    data: numpy array of xarray
        Data, that should be plotted with this levels.
    lev_to_data: bool
        Switch to correct the levels to the actual data range.
        This is needed for safe plotting on triangular grid with cartopy.
    Returns
    -------
    data_levels: numpy array
        resulted levels.
    """
    if levels is not None:
        if len(levels) == 3:
            mmin, mmax, nnum = levels
            if lev_to_data:
                mmin, mmax = levels_to_data(mmin, mmax, data)
            nnum = int(nnum)
            data_levels = np.linspace(mmin, mmax, nnum)
        elif len(levels) < 3:
            raise ValueError(
                "Levels can be the list or numpy array with three or more elements."
            )
        else:
            data_levels = np.array(levels)
    else:
        mmin = np.nanmin(data)
        mmax = np.nanmax(data)
        nnum = 40
        data_levels = np.linspace(mmin, mmax, nnum)
    return data_levels

In [None]:
def levels_to_data(mmin, mmax, data):
    """Clip (mmin, mmax) to the symmetric data range [-max|data|, +max|data|].

    Cartopy cannot plot on a triangular mesh when the colour range is larger
    than the data range, so a level bound lying outside +/- the largest
    absolute (NaN-ignoring) data value is pulled in to that value. Using a
    symmetric range keeps the levels centred on 0. A message is printed
    whenever a bound is changed.

    Returns
    -------
    (mmin, mmax)
        The possibly adjusted lower and upper bounds.
    """
    limit = np.nanmax(abs(data))
    if mmin < -limit:
        mmin = -limit
        print("minimum level changed to make cartopy happy")
    if mmax > limit:
        mmax = limit
        print("maximum level changed to make cartopy happy")
    return mmin, mmax

In [None]:
# be_image = plt.imread('/work/ab0995/a270166/c6_EKE_files/NBMS.jfif') #NASA image background.
#added tickdist and dpi, changed lw of coasts, leave nan as nan - only works for tri plot
#from pyfesom2
def _elements_in_box(mesh, box):
    """Return a boolean mask over mesh.elem: True where all vertices lie inside `box`."""
    left, right, down, up = box
    inside = (mesh.x2 >= left) & (mesh.x2 <= right) & (mesh.y2 >= down) & (mesh.y2 <= up)
    return np.all(inside[mesh.elem], axis=1)


def btplot(
    mesh,
    data,
    cmap=None,
    box=[-180, 180, -90, 90],
    mapproj="pc",
    levels=None,
    ptype="cf",
    units=r"$^\circ$C",
    figsize=(10, 10),
    rowscol=(1, 1),
    titles=None,
    lw=0.01,
    fontsize=12,
    box_expand=1,
    dpi=100,
    tickdist=10,
    ext='both'
):
    """Plot original FESOM2 field(s) on a cartopy map using tricontourf or tripcolor.

    Adapted from pyfesom2: adds `dpi`, `tickdist` and `ext`, draws a
    light-grey background under the data, uses 0.5-wide coastlines and leaves
    NaN values untouched.

    Parameters
    ----------
    mesh: mesh object
        FESOM2 mesh object.
    data: np.array, xarray or list of those
        FESOM2 data on nodes, or on elements (elements only with ptype='tri').
        For u, v, u_ice and v_ice one has to first interpolate from elements
        to nodes (`tonodes` function).
    cmap: str or colormap
        Name of a colormap from cmocean or the standard matplotlib set
        (resolved with `get_cmap`).
    box: list
        Map boundaries in [lon_min, lon_max, lat_min, lat_max] format used for
        data selection and plotting (default [-180, 180, -90, 90]).
    mapproj: str
        Map projection: Mercator (merc), Plate Carree (pc), North Polar Stereo
        (np), South Polar Stereo (sp), Robinson (rob).
    levels: list
        (min, max, numberOfLevels); a list with more than 3 values is used as
        the individual level values. If None, min/max of the data with 40
        levels are used.
    ptype: str
        Plot type: tricontourf ('cf') or tripcolor ('tri').
    units: str
        Label of the colour bar.
    figsize: tuple
        Figure size in inches.
    rowscol: tuple
        Number of rows and columns.
    titles: str or list
        Title of the plot (if string) or of each subplot (if list of strings).
    lw: float
        Line width passed to tripcolor.
    fontsize: float
        Font size of the colour-bar tick labels and label.
    box_expand: float
        How much (in degrees) the selected part of the mesh extends beyond
        `box` to avoid white boundaries. Default is 1.
    dpi: float
        Figure resolution in dots per inch.
    tickdist: float
        Spacing of the colour-bar ticks.
    ext: str
        Colour-bar extension ('neither', 'both', 'min' or 'max'); also used as
        the contour `extend` for ptype='cf'.

    Returns
    -------
    fig, ax
        The figure and its first axis.

    Raises
    ------
    ValueError
        On empty data, mismatched titles, too few subplots, an unknown
        `ptype`, element data with ptype='cf', or data whose length matches
        neither the mesh nodes nor elements.
    """
    if not isinstance(data, list):
        data = [data]
    if not data:
        raise ValueError("`data` is empty; nothing to plot.")
    if titles:
        if not isinstance(titles, list):
            titles = [titles]
        if len(titles) != len(data):
            raise ValueError(
                "The number of titles do not match the number of data fields, please adjust titles (or put to None)"
            )
    if (rowscol[0] * rowscol[1]) < len(data):
        raise ValueError(
            "Number of rows*columns is smaller than number of data fields, please adjust rowscol."
        )
    # Validate before any figure is created so no empty figure is left behind.
    if ptype not in ("tri", "cf"):
        raise ValueError(
            "Only `cf` (contourf) and `tri` (tripcolor) options are supported."
        )

    sfmt = ticker.ScalarFormatter(useMathText=True)
    sfmt.set_powerlimits((-3, 4))

    colormap = get_cmap(cmap=cmap)
    box_mesh = [
        box[0] - box_expand,
        box[1] + box_expand,
        box[2] - box_expand,
        box[3] + box_expand,
    ]

    # The mesh subset does not depend on the field, so compute it once.
    region = pf.cut_region(mesh, box_mesh)
    if isinstance(region, tuple):
        # NOTE(review): handles a cut_region that returns (elements, mask) - confirm
        # which form the installed pyfesom2 version uses.
        elem_no_nan, no_nan_triangles = region
    else:
        elem_no_nan, no_nan_triangles = region, None
    no_cyclic_elem2 = pf.get_no_cyclic(mesh, elem_no_nan)
    elem_to_plot = elem_no_nan[no_cyclic_elem2]

    fig, ax = create_proj_figure(mapproj, rowscol, figsize, dpi=dpi)
    ax = ax.flatten() if isinstance(ax, np.ndarray) else [ax]

    for ind, data_to_plot in enumerate(data):
        data_levels = get_plot_levels(levels, data_to_plot, lev_to_data=True)

        ax[ind].set_extent(box, crs=ccrs.PlateCarree())
        # Light-grey background drawn beneath the data.
        ax[ind].fill_between(x=[-180, 180], y1=[-90], y2=[90], color='lightgrey')

        # Masked values do not work in cartopy, so NaNs are passed through as-is.
        n_values = data_to_plot.shape[0]
        if n_values == mesh.n2d:
            pass  # node data is plotted directly on the selected elements
        elif n_values == mesh.e2d:
            if ptype == "cf":
                raise ValueError(
                    "You are trying to plot data on elements using countourf, this will not work. Use `ptype='tri'` instead."
                )
            if no_nan_triangles is None:
                no_nan_triangles = _elements_in_box(mesh, box_mesh)
            # Reduce element data to the same elements as elem_to_plot.
            data_to_plot = data_to_plot[no_nan_triangles][no_cyclic_elem2]
        else:
            raise ValueError(
                f"Data has {n_values} values, matching neither the mesh nodes "
                f"({mesh.n2d}) nor the mesh elements ({mesh.e2d})."
            )

        if ptype == "tri":
            image = ax[ind].tripcolor(
                mesh.x2,
                mesh.y2,
                elem_to_plot,
                data_to_plot,
                transform=ccrs.PlateCarree(),
                cmap=colormap,
                vmin=data_levels[0],
                vmax=data_levels[-1],
                edgecolors="face",
                lw=lw,
                alpha=1,
            )
        else:
            image = ax[ind].tricontourf(
                mesh.x2,
                mesh.y2,
                elem_to_plot,
                data_to_plot,
                levels=data_levels,
                transform=ccrs.PlateCarree(),
                cmap=colormap,
                extend=ext,
                # `antialiased` is the contourf keyword; the former `antialiasing`
                # and `shading` kwargs were unused by tricontourf.
                antialiased=True,
            )

        ax[ind].coastlines(lw=0.5, resolution="110m", facecolor='grey')

        if titles:
            ax[ind].set_title(titles[ind], size=20)

    # Remove subplots that received no data; keep only used axes for the colour bar.
    for unused_ax in ax[len(data):]:
        fig.delaxes(unused_ax)
    used_axes = ax[:len(data)]

    # Ticks follow the requested (min, max) when given; otherwise the levels used.
    if levels is not None and len(levels) == 3:
        tick_min, tick_max = levels[0], levels[1]
    else:
        tick_min, tick_max = data_levels[0], data_levels[-1]
    tick_array = np.arange(tick_min, tick_max + 1, tickdist)

    cb = fig.colorbar(
        image,
        orientation="horizontal",
        ax=used_axes,
        pad=0.01,
        shrink=0.8,
        ticks=tick_array,
        extend=ext,
        format=sfmt,
    )
    cb.ax.tick_params(labelsize=fontsize)
    if units:
        cb.set_label(units, size=fontsize)

    return fig, ax[0]

In [None]:
# Sequential white -> blue-green colour map used for the mean-EKE maps.
rrcml = [
    '#fff7fb', '#ece2f0', '#d0d1e6',
    '#a6bddb', '#67a9cf', '#3690c0',
    '#02818a', '#016c59', '#014636',
]

wbg = LinearSegmentedColormap.from_list('mycmap', rrcml)
# Values outside the colour range take the end colours of the list.
wbg.set_under(mpl.colors.to_rgba(rrcml[0]))
wbg.set_over(mpl.colors.to_rgba(rrcml[-1]))

Plot the data (Figure 2: panel c first, then panels b and a).

In [None]:
# Figure 2c: modeled change in mean EKE (projected mean minus historic mean).
# Points where griddes.coast == 1 are masked; x10000 converts m^2/s^2 to cm^2/s^2.
eke_change = rmdif.where(griddes.coast != 1).values * 10000

dfig, ax = btplot(
    mesh,
    eke_change,
    levels=[-700, 700, 256],
    cmap='RdBu_r',
    units=r'$\Delta$EKE (cm$^2$/s$^2$)',
    figsize=(14, 8),
    ptype='tri',
    lw=0.15,
    fontsize=20,
    dpi=800,
    tickdist=200,
    ext='both',
)

# Outline the 3 S - 3 N band; only the first line carries the legend label.
ax.axhline(-3, color='lightgray', label='3 °S to 3 °N replaced with interpolated monthly data')
ax.axhline(3, color='lightgray')

plt.yticks(np.arange(-90, 91, 30), fontsize=18)
plt.xticks(np.arange(-180, 181, 45), fontsize=16)
ax.set_xlabel('Longitude (°E)', fontsize=20)
ax.set_ylabel('Latitude (°N)', fontsize=20)

# Panel letter in the upper-left corner, positioned in axes-fraction coordinates.
ax.text(0.02, 0.95, 'c', fontsize=24, weight='bold',
        horizontalalignment='left', verticalalignment='center',
        transform=ax.transAxes)

ax.legend(fontsize=18, loc=4)

dfig.savefig('Figure_2c.png', bbox_inches='tight')

In [None]:
# Figure 2b: observed mean EKE over the observational period (1993-2020).
# Points where griddes.coast == 1 are masked; x10000 converts m^2/s^2 to cm^2/s^2.
ofig, ax = btplot(
    mesh,
    obsm.where(griddes.coast != 1) * 10000,
    ptype='tri',
    mapproj='pc',
    lw=0.15,
    levels=[0, 1500, 41],
    cmap=wbg,
    units='EKE (cm$^2$/s$^2$)',
    figsize=(14, 8),
    fontsize=20,
    dpi=800,
    tickdist=500,
    ext='max',
)

plt.yticks(np.arange(-90, 91, 30), fontsize=18)
plt.xticks(np.arange(-180, 181, 45), fontsize=16)
plt.xlabel('Longitude (°E)', fontsize=20)
plt.ylabel('Latitude (°N)', fontsize=20)

# Panel letter in the upper-left corner, positioned in axes-fraction coordinates.
plt.text(0.02, 0.95, 'b', fontsize=26, weight='bold',
         horizontalalignment='left', verticalalignment='center',
         transform=ax.transAxes)

# File name fixed: was 'Figue_2b.png', inconsistent with Figure_2a/Figure_2c.
ofig.savefig('Figure_2b.png')

In [None]:
# Figure 2a: modeled mean EKE over the observational period (1993-2020).
# Points where griddes.coast == 1 are masked; x10000 converts m^2/s^2 to cm^2/s^2.
mean_eke_style = dict(
    ptype='tri',
    mapproj='pc',
    lw=0.15,
    levels=[0, 1500, 41],
    cmap=wbg,
    units='EKE (cm$^2$/s$^2$)',
    figsize=(14, 8),
    fontsize=20,
    dpi=800,
    tickdist=500,
    ext='max',
)
mfig, ax = btplot(mesh, modm.where(griddes.coast != 1) * 10000, **mean_eke_style)

# Outline the 3 S - 3 N band; only the first line carries the legend label.
ax.axhline(-3, color='lightgray', label='3 °S to 3 °N replaced with interpolated monthly data')
ax.axhline(3, color='lightgray')

plt.yticks(np.arange(-90, 91, 30), fontsize=18)
plt.xticks(np.arange(-180, 181, 45), fontsize=16)
ax.set_xlabel('Longitude (°E)', fontsize=20)
ax.set_ylabel('Latitude (°N)', fontsize=20)

# Panel letter in the upper-left corner, positioned in axes-fraction coordinates.
ax.text(0.02, 0.95, 'a', fontsize=24, weight='bold',
        horizontalalignment='left', verticalalignment='center',
        transform=ax.transAxes)
ax.legend(fontsize=18, loc=4)

mfig.savefig('Figure_2a.png')