# Nissen et al., 2023: Severe 21st-century OA in Antarctic MPAs
#
# Script to compute surface-referenced potential density (rho0) and save it as netCDF files (one per year)
# Requires the full model output (monthly temperature and salinity fields)
# Note that some paths are hard-coded throughout the script

In [None]:
#!jupyter nbconvert --to script plot_PAPER2_rho0_fields_save_as_netcdf.ipynb

In [1]:
import sys
import os
sys.path.append("../pyfesom/") # add pyfesom to search path
sys.path.append("../python_gsw_py3/") 
import pyfesom as pf
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from matplotlib import cm
from netCDF4 import Dataset, MFDataset
import pandas as pd
from gsw import rho # rho from SA, CT, p
from gsw import pt0_from_t # potTemp from SA, t, p (at reference pressure 0)
from gsw import pt_from_t # potTemp from SA, t, p and reference pressure
from gsw import pot_rho_t_exact # potRho from SA, t, p and reference pressure
from gsw import p_from_z # get pressure from z and lat
from gsw import CT_from_pt # conservative temp from potTemp and SA
from gsw import rho_first_derivatives # first derivatives of rho with respect to SA, CT, p
from gsw2 import sigma0_pt0_exact # gsw_sigma0_pt0_exact(SA,pt0)
#from gsw import alpha_wrt_CT_t_exact #alpha_wrt_CT_t_exact(SA, t, p)
#from gsw import alpha_wrt_pt_t_exact #alpha_wrt_CT_t_exact(SA, t, p)
#from gsw import beta_const_pt_t_exact #beta_const_pt_t_exact(SA, t, p)
from gsw import SA_from_SP
from numba import njit
from tqdm import tqdm

In [None]:
#
# from /global/homes/c/cnissen/scripts/python_gsw_py3/gsw/gibbs//thermodynamics_from_t.py

# Switch for the local copies of the gsw routines defined in this cell.
# NOTE(review): when True, the definitions below shadow the sigma0_pt0_exact
# and SA_from_SP imported at the top of the file.  Some names they use
# (SA_from_SP_Baltic, Bunch) are not defined in the code visible here -
# confirm they are available before enabling.
load_gsw_functions = False
if load_gsw_functions:
    # Salinity scaling factor used in the Gibbs-function polynomials.
    sfac = 0.0248826675584615
    """
    sfac = 1 / (40 * uPS) = 1 / (40. * (SSO / 35.))
    """

    # Standard Ocean Reference Salinity [g/kg].
    SSO = 35.16504
    """
    SSO is the Standard Ocean Reference Salinity (35.16504 g/kg.)

    SSO is the best estimate of the Absolute Salinity of Standard Seawater
    when the seawater sample has a Practical Salinity, SP, of 35
    (Millero et al., 2008), and this number is a fundamental part of the
    TEOS-10 definition of seawater.
    """

    def sigma0_pt0_exact(SA, pt0):
        r"""
        Calculates potential density anomaly with reference sea pressure of
        zero (0) dbar.  The temperature input to this function is potential
        temperature referenced to zero dbar.

        Relies on the module-level constant ``sfac`` defined in this cell.

        Parameters
        ----------
        SA : array_like
             Absolute salinity [g kg :sup:`-1`]
        pt0 : array_like
              potential temperature [:math:`^\circ` C (ITS-90)]
              with respect to a reference sea pressure of 0 dbar

        Returns
        -------
        sigma0_pt0_exact : array_like
                           potential density anomaly [kg m :sup:`-3`]
                           respect to a reference pressure of 0 dbar

        Examples
        --------
        >>> import gsw
        >>> SA = [34.7118, 34.8915, 35.0256, 34.8472, 34.7366, 34.7324]
        >>> pt0 = [28.7832, 28.4209, 22.7850, 10.2305, 6.8292, 4.3245]
        >>> gsw.sigma0_pt0_exact(SA, pt0)
        array([ 21.79814475,  22.05251193,  23.89356369,  26.66762521,
                27.10723499,  27.4096324 ])

        References
        ----------
        .. [1] IOC, SCOR and IAPSO, 2010: The international thermodynamic equation
           of seawater - 2010: Calculation and use of thermodynamic properties.
           Intergovernmental Oceanographic Commission, Manuals and Guides No. 56,
           UNESCO (English), 196 pp. See Eqn. (3.6.1).
        """

        SA = np.maximum(SA, 0)  # Ensure that SA is non-negative.
        # Scaled salinity variables: x2 = sfac*SA and x = sqrt(x2).
        x2 = sfac * SA
        x = np.sqrt(x2)
        # Scaled temperature: y = pt0 / 40.
        y = pt0 * 0.025
        # Terms that depend on temperature only.
        g03 = (100015.695367145 +
               y * (-270.983805184062 +
               y * (1455.0364540468 +
               y * (-672.50778314507 +
               y * (397.968445406972 +
               y * (-194.618310617595 +
               y * (63.5113936641785 -
               y * 9.63108119393062)))))))
        # Terms that scale with salinity (every term carries the factor x2).
        g08 = x2 * (-3310.49154044839 +
                    x * (199.459603073901 +
                    x * (-54.7919133532887 +
                    x * 36.0284195611086 -
                    y * 22.6683558512829) +
                    y * (-175.292041186547 +
                    y * (383.058066002476 +
                    y * (-460.319931801257 +
                    y * 234.565187611355)))) +
                    y * (729.116529735046 +
                    y * (-860.764303783977 +
                    y * (694.244814133268 +
                    y * (-297.728741987187)))))

        # The above code is equivalent to the following line of code:
        # sigma0_pt_exact = rho_t_exact(SA, pt0, 0.) - 1000

        # Density = 1e8 / (g03 + g08); subtract 1000 to get the anomaly.
        return 100000000. / (g03 + g08) - 1000.0

    def check_input(SP, p, lon, lat):
        """
        Normalise the inputs shared by the "from_SP" functions.

        All four inputs are broadcast to a common shape.  Practical Salinity
        values that are implausibly high for their pressure (> 120 above
        100 dbar, > 42 at or below 100 dbar) are masked rather than modified.
        Longitude is reduced modulo 360, and negative salinities are clipped
        to zero.

        Returns
        -------
        SP, p, lon, lat : arrays broadcast to a common shape
        """
        lon, lat, p, SP = np.broadcast_arrays(lon, lat, p, SP, subok=True)

        too_salty_shallow = (p < 100) & (SP > 120)
        too_salty_deep = (p >= 100) & (SP > 42)
        if too_salty_shallow.any() or too_salty_deep.any():
            # Build a new masked array so the caller's input is left untouched.
            bad = (np.ma.filled(too_salty_shallow, False)
                   | np.ma.filled(too_salty_deep, False))
            SP = np.ma.array(SP, mask=bad)

        lon = lon % 360

        # Upstream gsw kept range checks on p/lon/lat here inside an
        # ``if False:`` guard; they never ran and are therefore omitted.

        SP = np.maximum(SP, 0)  # also valid for masked arrays
        return SP, p, lon, lat

    def SA_from_SP(SP, p, lon, lat):
        """Convert Practical Salinity to Absolute Salinity.

        Computes SA = (SSO / 35) * SP * (1 + SAAR), with SAAR looked up from
        the gsw V3 table at (p, lon, lat).  Wherever ``SA_from_SP_Baltic``
        returns unmasked values, those replace the table-based result.
        Inputs are first passed through ``check_input``: negative SP becomes
        zero and out-of-range SP is masked.

        Parameters
        ----------
        SP : array_like
             salinity (PSS-78) [unitless]
        p : array_like
            pressure [dbar]
        lon : array_like
              decimal degrees east [0..+360] or [-180..+180]
        lat : array_like
              decimal degrees (+ve N, -ve S) [-90..+90]

        Returns
        -------
        SA : masked array
             Absolute salinity [g kg :sup:`-1`]

        Examples
        --------
        >>> import gsw
        >>> SP = [34.5487, 34.7275, 34.8605, 34.6810, 34.5680, 34.5600]
        >>> p = [10, 50, 125, 250, 600, 1000]
        >>> lon = 188
        >>> lat = 4
        >>> gsw.SA_from_SP(SP, p, lon, lat)
        array([ 34.71177834,  34.89152262,  35.02554486,  34.84722903,
                34.73662847,  34.73236307])

        References
        ----------
        .. [1] IOC, SCOR and IAPSO, 2010: The international thermodynamic equation
           of seawater - 2010: Calculation and use of thermodynamic properties.
           Intergovernmental Oceanographic Commission, Manuals and Guides No. 56,
           UNESCO (English), 196 pp. See section 2.5 and appendices A.4 and A.5.

        .. [2] McDougall, T.J., D.R. Jackett and F.J. Millero, 2010: An algorithm
           for estimating Absolute Salinity in the global ocean. Submitted to Ocean
           Science. A preliminary version is available at Ocean Sci. Discuss.,
           6, 215-242.
           http://www.ocean-sci-discuss.net/6/215/2009/osd-6-215-2009-print.pdf
        """
        SP, p, lon, lat = check_input(SP, p, lon, lat)

        # SAAR() returns (ratio, in_ocean); only the ratio is used.  The SAAR
        # table returns zero in the Baltic Sea, where the anomaly depends on
        # SP rather than on position, hence the separate Baltic step below.
        saar_ratio = SAAR(p, lon, lat)[0]
        SA = (SSO / 35) * SP * (1 + saar_ratio)

        # NOTE(review): SA_from_SP_Baltic is not defined in the code visible
        # in this file - confirm it exists before enabling load_gsw_functions.
        SA_baltic = SA_from_SP_Baltic(SP, lon, lat)
        if SA_baltic is None:
            return SA

        in_baltic = ~SA_baltic.mask
        SA[in_baltic] = SA_baltic[in_baltic]
        return SA

    # NOTE(review): h006 and h007 look like coefficients carried over from the
    # upstream gsw module; nothing in the code visible here references them.
    h006 = -2.1078768810e-9
    h007 =  2.8019291329e-10


    class SA_table(object):
        """
        Spatially interpolated look-up of the gsw V3 salinity-anomaly tables.

        On construction, ``deltaSA_ref`` and ``SAAR_ref`` are read from
        ``fname`` (via ``read_data``) and reordered to (x, y, z) =
        (lon, lat, pressure).  They are stored as masked arrays whose masked
        entries hold 0.  ``SAAR`` and ``delta_SA_ref`` interpolate the
        corresponding table to arbitrary (p, lon, lat) points.
        """
        # Central America barrier
        x_ca = np.array([260.0, 272.59, 276.5, 278.65, 280.73, 295.217])
        y_ca = np.array([19.55, 13.97, 9.6, 8.1, 9.33, 0.0])

        def __init__(self, fname="/global/homes/c/cnissen/scripts/python_gsw_py3/gsw/utilities/data/gsw_data_v3_0.npz", max_p_fudge=10000,
                     min_frac=0):
            self.fname = fname
            self.max_p_fudge = max_p_fudge
            self.min_frac = min_frac
            data = read_data(fname)
            # Builtin ``float`` instead of the ``np.float`` alias, which was
            # removed in NumPy 1.24 (AttributeError on current NumPy).
            self.lon = data.longs_ref.astype(float)
            self.lat = data.lats_ref.astype(float)
            self.p = data.p_ref                # pressure levels of the table [dbar]
            self.dlon = self.lon[1] - self.lon[0]
            self.dlat = self.lat[1] - self.lat[0]
            self.i_ca, self.j_ca = self.xy_to_ij(self.x_ca, self.y_ca)
            # Make the order x, y, z:
            # Start with deltaSA_ref (was delta_SA_ref in V2):
            temp = data.deltaSA_ref.transpose((2, 1, 0)).copy()
            self.dsa_ref = np.ma.masked_invalid(temp)
            self.dsa_ref.data[self.dsa_ref.mask] = 0
            # Now SAAR_ref, which did not exist in V2:
            temp = data.SAAR_ref.transpose((2, 1, 0)).copy()
            self.SAAR_ref = np.ma.masked_invalid(temp)
            self.SAAR_ref.data[self.SAAR_ref.mask] = 0

        def xy_to_ij(self, x, y):
            """
            Convert from lat/lon to grid index coordinates,
            without truncation or rounding.
            """
            i = (x - self.lon[0]) / self.dlon
            j = (y - self.lat[0]) / self.dlat
            return i, j

        def _central_america(self, di, dj, ii, jj, gm):
            """
            Use a line running through Central America to zero
            the goodmask for grid points in the Pacific forming
            the grid box around input locations in the Atlantic,
            and vice-versa.
            """
            ix, jy = ii[0] + di, jj[0] + dj  # Reconstruction: minor inefficiency.
            inear = ((ix >= self.i_ca[0]) & (ix <= self.i_ca[-1])
                     & (jy >= self.j_ca[-1]) & (jy <= self.j_ca[0]))
            if not inear.any():
                return gm
            inear_ind = inear.nonzero()[0]
            ix = ix[inear]
            jy = jy[inear]
            ii = ii[:, inear]
            jj = jj[:, inear]
            jy_ca = np.interp(ix, self.i_ca, self.j_ca)
            above = jy - jy_ca  # > 0 if input point is above dividing line
            # Intersections of left and right grid lines with dividing line
            jleft_ca = np.interp(ii[0], self.i_ca, self.j_ca)
            jright_ca = np.interp(ii[1], self.i_ca, self.j_ca)
            jgrid_ca = [jleft_ca, jright_ca, jright_ca, jleft_ca]
            # Zero the goodmask for grid points on opposite side of divider
            for i in range(4):
                opposite = (above * (jj[i] - jgrid_ca[i])) < 0
                gm[i, inear_ind[opposite]] = 0
            return gm

        def xy_interp(self, di, dj, ii, jj, k):
            """
            2-D interpolation, bilinear if all 4 surrounding
            grid points are present, but treating missing points
            as having the average value of the remaining grid
            points. This matches the matlab V2 behavior.
            """
            # Array of weights, CCW around the grid box
            w = np.vstack(((1 - di) * (1 - dj),  # lower left
                          di * (1 - dj),         # lower right
                          di * dj,               # upper right
                          (1 - di) * dj))        # upper left
            gm = ~self.dsa.mask[ii, jj, k]   # gm is "goodmask"
            gm = self._central_america(di, dj, ii, jj, gm)
            # Save a measure of real interpolation quality.
            frac = (w * gm).sum(axis=0)
            # Now loosen the interpolation, allowing a value to
            # be calculated on a grid point that is masked.
            # This matches the matlab gsw version 2 behavior.
            jm_partial = gm.any(axis=0) & (~(gm.all(axis=0)))
            # The weights of the unmasked points will be increased
            # by the sum of the weights of the masked points divided
            # by the number of unmasked points in the grid square.
            # This is equivalent to setting the masked data values
            # to the average of the unmasked values, and then
            # unmasking, which is the matlab v2 implementation.
            if jm_partial.any():
                w_bad = w * (~gm)
                w[:, jm_partial] += (w_bad[:, jm_partial].sum(axis=0) /
                                     gm[:, jm_partial].sum(axis=0))
            w *= gm
            wsum = w.sum(axis=0)
            valid = wsum > 0  # Only need to prevent division by zero here.
            w[:, valid] /= wsum[valid]
            w[:, ~valid] = 0
            vv = self.dsa.data[ii, jj, k]
            vv *= w
            dsa = vv.sum(axis=0)
            return dsa, frac

        def _delta_SA(self, p, lon, lat):
            """
            Table lookup engine--to be called only from SAAR or SA_ref.

            Interpolates ``self.dsa`` bilinearly in (lon, lat) and linearly
            in pressure; pressures deeper than the deepest valid table level
            of the surrounding grid box are clipped to that level.
            """
            p = np.ma.masked_less(p, 0)
            mask_in = np.ma.mask_or(np.ma.getmask(p), np.ma.getmask(lon))
            mask_in = np.ma.mask_or(mask_in, np.ma.getmask(lat))
            p, lon, lat = [np.ma.filled(a, 0).astype(float) for a in (p, lon, lat)]
            p, lon, lat = np.broadcast_arrays(p, lon, lat)
            if p.ndim > 1:
                shape_in = p.shape
                p, lon, lat = list(map(np.ravel, (p, lon, lat)))
                reshaped = True
            else:
                reshaped = False
            # p is clipped in place below.  broadcast_arrays may return a
            # stride-0 view (e.g. p of shape (1,) with 1-D lon/lat), in which
            # case writing one clipped value would change every element.
            # Work on a private copy.
            p = p.copy()
            p_orig = p.copy()  # Save for comparison to clipped p.
            ix0, iy0 = self.xy_to_ij(lon, lat)
            i0raw = np.floor(ix0).astype(int)
            i0 = np.clip(i0raw, 0, len(self.lon) - 2)
            di = ix0 - i0
            j0raw = np.floor(iy0).astype(int)
            j0 = np.clip(j0raw, 0, len(self.lat) - 2)
            dj = iy0 - j0
            # Start at lower left and go CCW; match order in _xy_interp.
            ii = np.vstack((i0, i0 + 1, i0 + 1, i0))
            jj = np.vstack((j0, j0, j0 + 1, j0 + 1))
            k1 = np.searchsorted(self.p, p, side='right')
            # Clip p and k1 at max p of grid cell.
            kmax = (self.ndepth[ii, jj].max(axis=0) - 1)
            mask_out = kmax.mask  # only used by the disabled options below
            kmax = kmax.filled(1)
            clip_p = (p >= self.p[kmax])
            p[clip_p] = self.p[kmax[clip_p]]
            k1[clip_p] = kmax[clip_p]
            k0 = k1 - 1
            dsa0, frac0 = self.xy_interp(di, dj, ii, jj, k0)
            dsa1, frac1 = self.xy_interp(di, dj, ii, jj, k1)
            dp = np.diff(self.p)
            pfrac = (p - self.p[k0]) / dp[k0]
            delta_SA = dsa0 * (1 - pfrac) + dsa1 * pfrac
            # Save intermediate results in case we are curious about
            # them; the frac values are most likely to be useful.
            # We won't bother to reshape them, though, and we may
            # delete them later.
            self.dsa0 = dsa0
            self.frac0 = frac0
            self.dsa1 = dsa1
            self.frac1 = frac1
            self.pfrac = pfrac
            self.p_fudge = p_orig - p
            # Editing options, in case we don't want to use
            # values calculated from the wrong pressure, or from
            # an incomplete SA table grid square.
            # mask_out |= self.p_fudge > self.max_p_fudge
            # mask_out |= self.frac1 < self.min_frac
            # delta_SA = np.ma.array(delta_SA, mask=mask_out, copy=False)
            # Later on, it is expected to be a masked array.
            delta_SA = np.ma.array(delta_SA, copy=False)
            if reshaped:
                delta_SA.shape = shape_in
                self.p_fudge.shape = shape_in
            if mask_in is not np.ma.nomask:
                delta_SA = np.ma.array(delta_SA, mask=mask_in, copy=False)
            return delta_SA

        def SAAR(self, p, lon, lat):
            """
            Table lookup of salinity anomaly ratio, given pressure, lon, and lat.
            """
            self.dsa = self.SAAR_ref
            # In V2,
            # ndepth from the file disagrees with the unmasked count from
            # SAAR_ref in a few places; this should be fixed in the
            # file, but for now we will simply calculate ndepth directly from
            # SAAR_ref.
            # TODO: check to see whether this discrepancy is also found in V3.
            # NOTE: ndepth is used by _delta_SA to clip pressures at the
            #       deepest valid level of each grid box.
            # self.ndepth = np.ma.masked_invalid(data.ndepth_ref.T).astype(np.int8)
            ndepth = self.dsa.count(axis=-1)
            self.ndepth = np.ma.masked_equal(ndepth, 0)
            return self._delta_SA(p, lon, lat)

        def delta_SA_ref(self, p, lon, lat):
            """
            Table lookup of salinity anomaly reference value, given pressure,
            lon, and lat.
            """
            self.dsa = self.dsa_ref
            # See comment in previous method.
            ndepth = self.dsa.count(axis=-1)
            self.ndepth = np.ma.masked_equal(ndepth, 0)
            return self._delta_SA(p, lon, lat)

    def SAAR(p, lon, lat):
        """
        Absolute Salinity Anomaly Ratio (excluding the Baltic Sea).
        Calculates the Absolute Salinity Anomaly Ratio, SAAR, in the open ocean
        by spatially interpolating the global reference data set of SAAR to the
        location of the seawater sample.
        This function uses version 3.0 of the SAAR look up table.

        Parameters
        ----------
        p : array_like
            pressure [dbar]
        lon : array_like
              decimal degrees east (will be treated modulo 360)
        lat : array_like
              decimal degrees (+ve N, -ve S) [-90..+90]

        Returns
        -------
        SAAR : array
               Absolute Salinity Anomaly Ratio [unitless]
        in_ocean : boolean array

        Notes
        -----
        The Absolute Salinity Anomaly Ratio in the Baltic Sea is evaluated
        separately, since it is a function of Practical Salinity, not of space.
        The present function returns a SAAR of zero for data in the Baltic Sea.
        The correct way of calculating Absolute Salinity in the Baltic Sea is by
        calling SA_from_SP.
        The in_ocean flag is only set when the observation is well and truly on dry
        land; often the warning flag is not set until one is several hundred
        kilometers inland from the coast.

        The lookup table is built once, on the first call, and reused after
        that.

        The algorithm is taken from the matlab implementation of the references,
        but the numpy implementation here differs substantially from the
        matlab implementation.

        References
        ----------
        .. [1] IOC, SCOR and IAPSO, 2010: The international thermodynamic equation
           of seawater - 2010: Calculation and use of thermodynamic properties.
           Intergovernmental Oceanographic Commission, Manuals and Guides No. 56,
           UNESCO (English), 196 pp.

        .. [2] McDougall, T.J., D.R. Jackett and F.J. Millero, 2010: An algorithm
           for estimating Absolute Salinity in the global ocean.  Submitted to
           Ocean Science. A preliminary version is available at Ocean Sci.
           Discuss., 6, 215-242.
           http://www.ocean-sci-discuss.net/6/215/2009/osd-6-215-2009-print.pdf
        """
        # Constructing an SA_table transposes and masks two full 3-D lookup
        # tables.  SA_from_SP calls this function for every depth level of
        # every month, so build the table once and keep it on the function.
        # Each SA_table.SAAR call resets the per-call state (dsa, ndepth)
        # before use, so reusing the instance is safe.
        table = getattr(SAAR, '_table', None)
        if table is None:
            table = SA_table()
            SAAR._table = table
        saar = table.SAAR(p, lon, lat)
        return saar, ~saar.mask

    def read_data(fname, datadir=None):
        """
        Load the arrays stored in the numpy ``.npz`` file *fname*.

        Delegates to the module-level ``_npz_cache`` loader, which memoises
        results per path.  Repeated requests for the same file therefore
        return the cached object instead of re-reading it.  *datadir*, when
        given, is joined in front of *fname*.
        """
        loader = _npz_cache
        return loader(fname, datadir=datadir)

    class Cache_npz(object):
        """
        Memoising loader for numpy ``.npz`` files.

        Calling an instance with ``(fname, datadir=None)`` loads
        ``os.path.join(datadir, fname)`` into a ``Bunch``.  The result is
        cached under that path, so each file is read from disk only once.
        When ``datadir`` is None, a ``data`` directory next to this module is
        used.  If ``__file__`` is undefined, e.g. inside a notebook, the
        current working directory is used instead.  An absolute ``fname``
        overrides the directory in either case.

        NOTE(review): ``Bunch`` is not defined in the code visible in this
        file; it must be provided before an instance is called on an
        uncached file.
        """

        def __init__(self):
            self._cache = dict()
            # ``__file__`` does not exist inside a Jupyter kernel.  There the
            # plain ``os.path.dirname(__file__)`` raised NameError as soon as
            # ``_npz_cache`` below was created.
            try:
                base_dir = os.path.dirname(__file__)
            except NameError:
                base_dir = os.getcwd()
            self._default_path = os.path.join(base_dir, 'data')

        def __call__(self, fname, datadir=None):
            if datadir is None:
                datadir = self._default_path
            fpath = os.path.join(datadir, fname)
            if fpath in self._cache:
                return self._cache[fpath]
            # np.load on an .npz returns an NpzFile that keeps the file open
            # until closed.  Copy the arrays into the Bunch inside ``with`` so
            # the handle is released.
            # NOTE(review): assumes Bunch copies the arrays eagerly (as a
            # dict built from the mapping does) - confirm.
            with np.load(fpath) as d:
                ret = Bunch(d)
            self._cache[fpath] = ret
            return ret

    _npz_cache = Cache_npz()


In [None]:
#---
# some settings
#-----

# FESOM mesh, including the 3-D node information
meshpath = '/pscratch/sd/c/cnissen/mesh_COARZE/'
mesh = pf.load_mesh(meshpath, get3d=True, usepickle=False)

# years to process, inclusive (a previous setting was 2010-2089)
first_year, last_year = 1990, 2100
year_list = np.arange(first_year, last_year + 1, 1)
print (year_list)

# NOTE(review): ref_pressure is not used in the code below; this script
# only computes the surface-referenced (0 dbar) potential density.
ref_pressure = 0


In [None]:
#---
# load mesh info
#---

path_mesh = '/pscratch/sd/c/cnissen/'
file_mesh = 'Nissen2022_FESOM_REcoM_mesh_information_corrected_20220910.nc'

# 2-D mesh information; the context manager closes the file even if one of
# the reads fails.
with Dataset(path_mesh+file_mesh) as f1:
    lats       = f1.variables['lat'][:]
    lons       = f1.variables['lon'][:]
    zlevs      = f1.variables['zlevs'][:]
    cavities   = f1.variables['cavity'][:]
    topo       = f1.variables['topo'][:]
    area_nodes = f1.variables['cell_area'][:]
    volume     = f1.variables['cell_volume'][:]
print(lats.shape)

# surface-node indices with cavity flag 0 / 1
ind_no_cavity = np.where(cavities==0)[0]
ind_cavities  = np.where(cavities==1)[0]

# 3-D node coordinates.  sep=r'\s+' is the documented replacement for
# delim_whitespace=True, which is deprecated since pandas 2.2.
df = pd.read_csv('/pscratch/sd/c/cnissen/HLRN_runs_postprocessed/nod3d.out', sep=r'\s+', skiprows=1,
                 names=['node_number','x','y','z','flag'])
lats3d  = df.y.values
lons3d  = df.x.values
zlevs3d = df.z.values
print (zlevs3d.shape)


In [None]:
#-----
# get pressure array and convert to depths x nodes
#-----

# gsw's p_from_z expects height z (negative below the sea surface).  Using
# -|z| covers node depths stored either as negative heights or as positive
# depths.  The previous if/elif left ``p`` undefined when the minimum fell
# between -10 and 0.
p = p_from_z(-np.abs(zlevs3d), lats3d)
print (np.min(p),np.max(p),p.shape)

# convert mesh.n3d to depth x mesh.n2d
# mesh.n32[surface_node, level] holds the 1-based index of the 3-D node at
# that level; entries that are not > 0 have no node and stay NaN.
# (Vectorized replacement of the former Python loop over levels x nodes.)
depth = np.unique(-1*zlevs)
n32 = np.asarray(mesh.n32)[:len(lons), :len(depth)].T  # -> depth x surface nodes
exists = n32 > 0
node3d = n32[exists] - 1  # 0-based indices into the 3-D node arrays

p_2d   = np.full((len(depth), len(lons)), np.nan)
lat_2d = np.full((len(depth), len(lons)), np.nan)
lon_2d = np.full((len(depth), len(lons)), np.nan)
p_2d[exists]   = p[node3d]
lat_2d[exists] = lats3d[node3d]
lon_2d[exists] = lons3d[node3d]
            

In [None]:
#----
# function
#----

@njit
def reorganize_field_in_cavities(ind_cavities,data):
    """
    Move the top value of each cavity column down to the level directly
    above the next valid value.

    Parameters
    ----------
    ind_cavities : 1-D integer array
        Node (column) indices to process.
    data : 2-D array, depth x nodes
        Field in which missing entries are encoded as exactly 0.  The caller
        replaces masked values by 0 first, because the mask is not used
        under njit.

    Returns
    -------
    data : the same array, modified in place.
    """
    for ii in ind_cavities:
        col = data[:, ii]  # all depth levels at the current cavity node
        # Valid levels are those not equal to the 0 "missing" sentinel.
        # The former ``> 0`` test also discarded legitimate negative values,
        # e.g. sub-zero potential temperatures, so temperature columns were
        # not shifted (or were shifted wrongly) and no longer lined up with
        # salinity.
        ind_av = np.where(col != 0)[0]
        # If the top valid value is followed by a gap, move it down to the
        # level just above the second valid value and clear its old slot.
        if len(ind_av) > 1 and (ind_av[1] - ind_av[0]) > 1:
            col[ind_av[1]-1] = col[ind_av[0]]
            col[ind_av[0]] = 0
        data[:, ii] = col  # write back (col is a view, kept for clarity)
    return data

@njit
def reorganize_pressure_field_in_cavities(ind_cavities,data,data_pressure):
    """
    Realign a coordinate field (pressure, lon or lat) in cavity columns with
    the valid levels of an already reorganized data field.

    For each node in ``ind_cavities``, let ``n`` be the number of valid
    levels in ``data``.  The first ``n`` entries of the coordinate column are
    placed, in order, at those valid levels; every other level becomes 0.

    NOTE: "data" should be the data that are already reorganized in cavities,
    and it must contain zeros (not masked values) where there is no data.
    Valid levels are therefore identified as ``!= 0``, which also works for
    fields with negative values.

    Returns
    -------
    data_pressure : the same array, modified in place.
    """
    for ii in ind_cavities:
        cc = data_pressure[:, ii]
        ind_av = np.where(data[:, ii] != 0)[0]  # levels holding valid data
        cc_aux = np.zeros_like(cc)
        cc_aux[ind_av] = cc[0:len(ind_av)]
        data_pressure[:, ii] = cc_aux  # overwrite original column

    return data_pressure


In [None]:
#------
# get monthly density for each year
#------

# Output settings (identical for every year and month)
save_to_netcdf = True
vari = 'rho0'
savepath = '/pscratch/sd/c/cnissen/HLRN_runs_postprocessed/PAPER2_postprocessed/rho0_fields/'
source = '/global/homes/c/cnissen/scripts/plot_FESOM_rho0_fields_save_as_netcdf.ipynb'

# The cavity correction below rewrites the global p_2d/lon_2d/lat_2d arrays.
# It must be applied only once to the arrays built by the pressure cell.
correct_pressure_array = True
for year in year_list:
    print ('Load output for year '+str(year)+'...')

    # "hist" directories before 2015, "ssp585" directories afterwards
    if year<2015:
        path1  = '/pscratch/sd/c/cnissen/COARZE_temp/hist/'
        path2  = '/pscratch/sd/c/cnissen/COARZE_salt/hist/'
    else:
        path1  = '/pscratch/sd/c/cnissen/COARZE_temp/ssp585/'
        path2  = '/pscratch/sd/c/cnissen/COARZE_salt/ssp585/'

    # The context managers close both input files once the year is done.
    # They were previously never closed, leaking two handles per year.
    with Dataset(path1+'thetao_fesom_'+str(year)+'0101.nc') as temp_file, \
         Dataset(path2+'so_fesom_'+str(year)+'0101.nc') as salt_file:

        for mm in tqdm(range(0,12)):

            # load data
            temp    = temp_file.variables['thetao'][mm,:,:] # potential temp, depth x nodes
            salt_sp = salt_file.variables['so'][mm,:,:]     # practical salinity, depth x nodes

            #------
            # calculate density, reference pressure 0
            #------
            # set masked values to 0 to get the correction within cavities right
            # (otherwise masked and unmasked values are not recognized correctly with njit)
            salt_sp[salt_sp.mask==True]=0
            temp[temp.mask==True]=0
            # move "surface" value in cavities to correct depth
            salt_sp = reorganize_field_in_cavities(ind_cavities,salt_sp)
            temp    = reorganize_field_in_cavities(ind_cavities,temp)
            if correct_pressure_array:
                # correct pressure array
                print ('Correct pressure array in cavities')
                p_2d[np.isnan(p_2d)]=0
                lon_2d[np.isnan(lon_2d)]=0
                lat_2d[np.isnan(lat_2d)]=0
                # salt_sp determines at which depth levels the coordinates
                # belong; it must hold zeros (not masked values) where there
                # is no data.
                p_2d = reorganize_pressure_field_in_cavities(ind_cavities,salt_sp,p_2d)
                lon_2d = reorganize_pressure_field_in_cavities(ind_cavities,salt_sp,lon_2d)
                lat_2d = reorganize_pressure_field_in_cavities(ind_cavities,salt_sp,lat_2d)
                # set zeros back to masked
                p_2d = np.ma.masked_where(p_2d==0,p_2d)
                p_2d[0,:] = 0 # surface layer should actually have zero pressure!!
                #
                # BUG fix Dec 2023: lon_2d/lat_2d used to be overwritten with
                # pressure fields here; rho fields were re-calculated for all years.
                #
                lon_2d = np.ma.masked_where(lon_2d==0,lon_2d)
                lat_2d = np.ma.masked_where(lat_2d==0,lat_2d)
                correct_pressure_array = False # only needs to be done once

            # set zeros back to masked
            salt_sp = np.ma.masked_where(salt_sp==0,salt_sp)
            temp = np.ma.masked_where(temp==0,temp)

            # practical -> absolute salinity, one depth level at a time
            # (author's note: the resulting salt_abs field looks weird!)
            salt_abs = np.zeros_like(salt_sp)
            for dd in range(0,salt_sp.shape[0]):
                salt_abs[dd,:] = SA_from_SP(salt_sp[dd,:], p_2d[dd,:], lon_2d[dd,:],lat_2d[dd,:])

            sigma = sigma0_pt0_exact(salt_abs,temp)  # from abs salinity and pot temp
            rho = np.squeeze(sigma) + 1000

            # save as netcdf
            if save_to_netcdf:
                netcdf_name = vari+'_fesom_'+str(year)+'0101.nc'
                outfile = savepath+netcdf_name

                if not os.path.exists(outfile):
                    print ('Create file '+outfile)
                    with Dataset(outfile, 'w', format='NETCDF4_CLASSIC') as w_nc_fid:
                        # create dimensions
                        w_nc_fid.createDimension('nodes_2d', mesh.n2d)
                        w_nc_fid.createDimension('depth', len(np.unique(zlevs)))
                        w_nc_fid.createDimension('time', 12)
                        w_nc_fid.script = source

                with Dataset(outfile, 'r+', format='NETCDF4_CLASSIC') as w_nc_fid:
                    # Create the variable only if it does not exist yet.  This
                    # replaces a bare ``except: pass`` that also hid real errors.
                    if vari not in w_nc_fid.variables:
                        w_nc_var1 = w_nc_fid.createVariable(vari, 'f4',('time','depth','nodes_2d'))
                        w_nc_var1.units = 'kg m-3'
                        w_nc_var1.description = 'Surface-referenced potential density'
                    w_nc_fid.variables[vari][mm,:,:] = rho

                if mm==11:
                    print ('Successfully saved '+vari+' for all months in year '+str(year))

print ('done')


In [None]:
#---
# test plot
#---
test_plot = False
if test_plot:
    print(rho.shape)

    dpicnt = 150

    print('Min/Max p_2d:',np.min(p_2d),np.max(p_2d))
    # One figure per diagnostic field: (field, extra contourf kwargs, colorbar?)
    panels = [
        (p_2d, {'levels': np.arange(0,6000+100,100)}, True),
        (temp, {}, False),
        (salt_sp, {}, True),
        (salt_abs, {}, False),
    ]
    for field, extra_kwargs, add_colorbar in panels:
        fig7 = plt.figure(num=18, figsize=(7,3), dpi=dpicnt, facecolor='w', edgecolor='k')
        plt.contourf(field, extend='both', cmap=cm.inferno, **extra_kwargs)
        if add_colorbar:
            plt.colorbar()
        plt.show()
