In [None]:
import numpy as np
import xarray as xr
import dask
import seawater as sw
import xgcm

In [None]:
# Input/output locations. Both strings must end with "/" because file names
# are appended to them by plain string concatenation.
path, eddypath = (
    "/path/to/model/output/",    # model output (zarr diagnostics, post/ files)
    "/path/to/tracked/eddies/",  # eddy/track masks; heat-transport files are written here too
)

## Calculate the different forms of heat transports on z-levels

The zonally integrated, time-averaged total HT is defined as  

$HT(y, z) = \rho C_{p} [\overline{v\, (T - T_{f})}]$,  

where $\rho = 1035$ kg$\,$m$^{-3}$ is the model's reference density, $C_{p} = 3994$ J$\,$kg$^{-1}\,$K$^{-1}$ is the specific heat capacity of seawater at constant pressure, $v$ is the meridional velocity, $T$ is the model's potential temperature and $T_{f}$ is the freezing point temperature (computed from the local salinity). The use of $T_{f}$ ensures that HT remains positive whenever the meridional velocity is positive. The square brackets $[\cdot]$ indicate zonal integration over the entire domain ($\oint\,\,dx$) and the overbar the decadal mean. 

The mean HT, $MHT$, is based on a monthly climatology computed over each decade as $MHT(y, z) = \rho C_{p} [\overline{\langle v\, \rangle \langle (T - T_{f}) \rangle}]$, where $\langle \cdot \rangle$ indicates the monthly climatology. This allows the calculation to be restricted to single seasons if desired, *e.g.* $MHT_{JJA}(y, z) = \rho C_{p} [\overline{\langle v\, \rangle \langle (T - T_{f}) \rangle}^{JJA}]$, where $\overline{\cdot}^{JJA}$ represents the decadal mean over June-July-August. The largest contribution to MHT is expected to be that of the surface Ekman transport, which can be computed as $MHT^{Ek}(y) = \rho C_{p} [\overline{\langle v^{Ek}\, \rangle \langle (T^{Ek} - T_{f}) \rangle}]$, where $T^{Ek}$ is the temperature averaged over the surface Ekman layer (assumed to be $50$ m deep) and $v^{Ek} = -\frac{\tau_{x}}{\rho\,f}$ is the meridional Ekman transport (per unit zonal width), with $\tau_{x}$ being the zonal wind stress and $f$ the Coriolis parameter.

The transient part of the meridional HT is calculated as  

$THT(y, z) = \rho C_{p} [\overline{v'T'}]$,  

where primes indicate the deviation from the monthly climatology such that $v' = v - \langle v \rangle$ and $T' = T - \langle T \rangle$. 

The contribution of coherent mesoscale eddies (CMEs) to the transient heat transport ($THT$) is obtained by multiplying the integrand $v'T'$ by the eddy mask $M^{CME}$ before the zonal integration. In this way only the $THT$ that occurs within coherent mesoscale eddies contributes to the integrals. 

$THT^{CME}(y, z) = \rho C_{p} [\overline{v'T'\,M^{CME}}]$

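First, a function to load the model output for a given decade together with the binary eddy masks and track masks (for all eddies, cyclones and anticyclones).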

In [None]:
def open_data(path, eddypath, y1, y2, mask_period="0201-0300"):
    """Load the model output for one period plus the eddy and track masks.

    Parameters
    ----------
    path : str
        Root directory of the model output; must end with a path separator
        because file names are appended by string concatenation.
    eddypath : str
        Directory holding the mask files; must end with a path separator.
    y1, y2 : str
        First and last year of the period as zero-padded strings (e.g. "0201").
        The time selection runs from ``y1-01-01`` to ``y2-12-30``.
    mask_period : str, optional
        Year range embedded in the mask file names (default "0201-0300").

    Returns
    -------
    tuple
        ``(data, THTdiv, em, em_cyclones, em_anticyclones, tm, tm_cyclones,
        tm_anticyclones)``.

        - ``data``: the ``zarr_Diags/output.5d.zarr`` store, restricted to the
          period.
        - ``THTdiv``: the ``post/THTvdiv.0???.nc`` files, restricted to the
          period.
        - The remaining six entries are the ``*_binary`` mask variables. They
          are squeezed, and ``lon``/``lat`` are renamed to ``XG``/``YG``.
    """
    period = slice(y1 + "-01-01", y2 + "-12-30")
    data = xr.open_zarr(path + "zarr_Diags/output.5d.zarr").sel(time=period)
    THTdiv = xr.open_mfdataset(path + "post/THTvdiv.0???.nc").sel(time=period)

    def _open_mask(stem):
        # File "<stem>_<mask_period>.nc" holds variable "<stem>_binary"; rename
        # its coordinates so it aligns with the model's XG/YG grid.
        ds = xr.open_mfdataset(eddypath + stem + "_" + mask_period + ".nc")
        return ds[stem + "_binary"].squeeze().rename({"lon": "XG", "lat": "YG"})

    em, em_cyclones, em_anticyclones, tm, tm_cyclones, tm_anticyclones = (
        _open_mask(stem)
        for stem in ("eddymask", "eddymask_cyclones", "eddymask_anticyclones",
                     "trackmask", "trackmask_cyclones", "trackmask_anticyclones")
    )
    return data, THTdiv, em, em_cyclones, em_anticyclones, tm, tm_cyclones, tm_anticyclones

A function to create the `xgcm.Grid()` instance for interpolation, set the constants needed for the calculation of heat transport, and prepare a mask that excludes the shelf (`maskShelf`) as well as a near-bottom reference depth (`nearBottom`, two wet levels above the deepest wet level).

In [None]:
def param_and_prepare(data, THTdiv):
    """Set up the xgcm grid, physical constants, and the fields/masks used later.

    Parameters
    ----------
    data : xarray.Dataset
        Model output for one period (see ``open_data``). A ``maskShelf``
        variable is added to this dataset in place, and the dataset is also
        returned.
    THTdiv : xarray.Dataset
        Dataset holding ``THTvdiv``. Its first time step at the first ``XC``
        column is used to locate the shelf. Assumes the dimension order
        (time, Z, YG, XC) — TODO confirm.

    Returns
    -------
    tuple
        ``(data, T, V, shelf, rho, Cp, f0, beta, Cd, f, savetime, dMask,
        nearBottom, grid)``, where:

        - ``T`` is ``THETA`` minus the local freezing point.
        - ``V`` is ``VVEL * hFacS``.
        - ``shelf`` is ``data.maskShelf``.
        - ``savetime`` is the middle time stamp, kept as a length-1 selection.
        - ``nearBottom`` is the depth two wet levels above the deepest wet level.
    """
    # Grid metrics handed to xgcm.
    metrics = {
        "X": ["dxC", "dxG", "dxF", "dxV"],  # X distances
        "Y": ["dyC", "dyG", "dyF", "dyU"],  # Y distances
        "Z": ["drF", "drW", "drS", "drC"],  # Z distances
        ("X", "Y"): ["rAw", "rAs", "rA", "rAz"],  # areas in the x-y plane
    }
    grid = xgcm.Grid(data, periodic=["X"], metrics=metrics)

    # Physical constants
    rho = 1035.  # reference density [kg m^-3]
    Cp = 3994.  # specific heat capacity [J kg^-1 K^-1]
    f0 = -1.405e-4  # reference Coriolis parameter [s^-1]
    beta = 1.145e-11  # beta (df/dy); assumes YC is in metres — TODO confirm
    Cd = 5.0e-2  # drag coefficient for calculation of bottom Ekman transport
    f = f0 + beta * data.YC  # beta-plane Coriolis parameter

    # Reference temperature: freezing point for the local salinity. The second
    # argument is the pressure passed to seawater's fp (presumably dbar — confirm).
    tref = sw.eos80.fp(data.SALT, 1.065)

    # Middle time stamp, kept as a length-1 coordinate so it can be used with
    # expand_dims when saving the heat transports to disk.
    mid = len(data.time) // 2
    savetime = data.time.isel(time=slice(mid, mid + 1))

    # Depth of every wet cell, 0 where maskS is 0 (assumes maskS is a 0/1 mask).
    # Z is negative downward, so the minimum over Z picks the deepest wet level.
    dMask = data.maskS * data.Z
    bottom = dMask.min("Z")
    # Drop the deepest wet level and take the next one up ...
    tmpDepth = dMask.where(dMask > bottom, other=0).min("Z")
    # ... then drop that one as well: nearBottom lies two wet levels above the bottom.
    nearBottom = dMask.where(dMask > tmpDepth, other=0).min("Z")

    # Shelf mask. For each level, zero every YG row up to and including the
    # first row where THTvdiv (first time step, first XC column) is non-zero.
    # The (Z, YG) section is loaded once instead of once per level.
    section = THTdiv.THTvdiv[0, :, :, 0].values
    maskShelf = np.ones(np.shape(THTdiv.THTvdiv.isel(time=0)))
    # np.argmax returns the index of the first True value. For a level with no
    # non-zero value it returns 0, so only the first row is masked there.
    firstwet = np.argmax(section != 0., axis=1)
    for k in range(maskShelf.shape[0]):
        maskShelf[k, 0:firstwet[k] + 1, :] = 0
    data["maskShelf"] = xr.DataArray(maskShelf, dims=["Z", "YG", "XC"])

    # Temperature relative to the freezing point and hFacS-weighted velocity
    T = data.THETA - tref
    V = data.VVEL * data.hFacS
    shelf = data.maskShelf
    return data, T, V, shelf, rho, Cp, f0, beta, Cd, f, savetime, dMask, nearBottom, grid

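Now the function that calculates the mean, Ekman and transient heat transports (for the full year and for each season). It also calculates the transient contributions inside eddies and eddy tracks, and writes everything to netCDF.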

In [None]:
def calc_hts(data, THTdiv, T, V, shelf, 
             em, em_cycl, em_anti, tm, tm_cycl, tm_anti, nBotDep, grid,
             ML, mld, mldm, mlds, mld_name):
    """Compute mean, Ekman and transient meridional heat transports and save them.

    All transports are zonally integrated (sum over ``XG`` weighted by
    ``data.dxV``). They are written to netCDF files in ``eddypath``, once for
    the full year and once per season (DJF, MAM, JJA, SON):

    * ``MHTs<mld_name>.[<season>.]<y1>0101_<y2>1230.nc``: mean (``MHT``),
      Ekman (``MHTek``) and transient (``THT``) heat transport.
    * ``THTs<mld_name>.[<season>.]<y1>0101_<y2>1230.nc``: three families of
      fields:

      - ``THT``;
      - ``THTref``: transient transport using the velocity at ``nBotDep`` as
        reference, for Watts et al. 2016;
      - ``THTdiv``: built from ``THTdiv.THTvdiv``.

      Each family is also given restricted to eddy masks (``*eddy*``) and
      track masks (``*track*``), for all eddies, cyclones and anticyclones.

    Parameters
    ----------
    data : xarray.Dataset
        Model output; ``oceTAUX``, ``drF`` and ``dxV`` are read from it.
    THTdiv : xarray.Dataset
        Dataset holding the ``THTvdiv`` field.
    T, V : xarray.DataArray
        Temperature (relative to the freezing point, as prepared by
        ``param_and_prepare``) and meridional velocity.
    shelf : xarray.DataArray
        Shelf mask. It is not used inside this function and is kept for
        interface compatibility.
    em, em_cycl, em_anti, tm, tm_cycl, tm_anti : xarray.DataArray
        Eddy masks and track masks (all / cyclones / anticyclones).
    nBotDep : xarray.DataArray
        Depth at which the reference velocity is taken (nearest ``Z`` level).
    grid : xgcm.Grid
        Grid used for the X/Y interpolations.
    ML : bool
        If True, multiply the integrands by the masks below before integrating.
    mld, mldm, mlds : xarray.DataArray or None
        Used only when ``ML`` is True.

        - ``mld``: applied to the transient fields (``tht``, ``thtref``,
          ``thtdiv``).
        - ``mldm``: applied to the monthly-climatology mean transport.
        - ``mlds``: applied to the seasonal mean transport.
    mld_name : str
        Tag inserted into the output file names (e.g. "", "ML", "NoML").

    Returns
    -------
    None

    Notes
    -----
    Reads the module-level names ``rho``, ``Cp``, ``f``, ``savetime``,
    ``eddypath``, ``y1`` and ``y2``; they must be defined before calling.

    NOTE(review): the seasonal ``MHT`` multiplies the seasonal means of ``V``
    and ``T``. The markdown formula instead averages the product of the
    monthly climatologies over the season — confirm which is intended.
    """
    split_chunks = {'array.slicing.split_large_chunks': True}
    seasons = ["DJF", "MAM", "JJA", "SON"]
    # common tail of every output file name
    period = str(y1) + '0101_' + str(y2) + '1230.nc'

    def _zonal_integral(field_x, mask=None):
        # Sum over XG of field (optionally times mask) times dxV.
        integrand = field_x if mask is None else field_x * mask
        return (integrand * data.dxV).sum("XG")

    def _season_mean(zint, season):
        # Mean of a zonal integral over all time steps falling in one season.
        return zint.groupby("time.season").mean("time").sel(season=season)

    def _save(fields, fname):
        # Stamp every field with the representative time `savetime` and write.
        xr.Dataset({name: da.expand_dims({"time": savetime})
                    for name, da in fields.items()}).to_netcdf(fname)

    # --- mean seasonal cycle (monthly climatology) and deviation (prime) ---
    with dask.config.set(**split_chunks):
        vbar = V.groupby('time.month').mean("time")
        tbar = T.groupby('time.month').mean("time")
        vprime = V.groupby('time.month') - vbar
        tprime = T.groupby('time.month') - tbar
    # --- temperature in the Ekman layer (averaged over the top 50 m) ---
    ekman_layer = slice(0, -50)
    tauXbar = data.oceTAUX.groupby('time.month').mean("time")
    tEKmean = ((T.sel(Z=ekman_layer) * data.drF.sel(Z=ekman_layer)).sum("Z")
               / data.drF.sel(Z=ekman_layer).sum().values)
    tEKbar = grid.interp(tEKmean.groupby('time.month').mean("time"), "Y", boundary="extend")
    # meridional Ekman transport from the zonal wind stress
    vEKbar = -tauXbar / (rho * f)
    # --- mean, Ekman and transient heat transport integrands ---
    mht = rho * Cp * (vbar * grid.interp(tbar, "Y", boundary="extend"))
    mhtek = rho * Cp * (grid.interp(vEKbar, "Y", boundary="extend") * grid.interp(tEKbar, "X"))
    tht = rho * Cp * (vprime * grid.interp(tprime, "Y", boundary="extend"))
    # if desired, mask out the regions specified in mld, mldm before integration
    if ML:
        mht = mht * mldm
        tht = tht * mld
    tht_x = grid.interp(tht, "X")
    THT_zint = _zonal_integral(tht_x)

    # integrate zonally, average over time and save
    _save({
        "MHT": _zonal_integral(grid.interp(mht, "X")).mean('month'),
        "MHTek": _zonal_integral(mhtek).mean('month'),
        "THT": THT_zint.mean("time"),
    }, eddypath + 'MHTs' + mld_name + '.' + period)

    # same as above but for the individual seasons
    for s in seasons:
        with dask.config.set(**split_chunks):
            vbar_season = V.groupby('time.season').mean("time").sel(season=s)
            tbar_season = T.groupby('time.season').mean("time").sel(season=s)
        tauXbar_season = data.oceTAUX.groupby('time.season').mean("time").sel(season=s)
        tEKbar_season = grid.interp(tEKmean.groupby('time.season').mean("time").sel(season=s),
                                    "Y", boundary="extend")
        vEKbar_season = -tauXbar_season / (rho * f)
        mht_season = rho * Cp * (vbar_season * grid.interp(tbar_season, "Y", boundary="extend"))
        mhtek_season = rho * Cp * (grid.interp(vEKbar_season, "Y", boundary="extend")
                                   * grid.interp(tEKbar_season, "X"))
        # if desired, mask out the regions specified in mlds before integration
        if ML:
            mht_season = mht_season * mlds
        _save({
            "MHT": _zonal_integral(grid.interp(mht_season, "X")),
            "MHTek": _zonal_integral(mhtek_season),
            "THT": _season_mean(THT_zint, s),
        }, eddypath + 'MHTs' + mld_name + '.' + s + '.' + period)

    # --- reference velocity for Watts et al 2016 ---
    with dask.config.set(**split_chunks):
        vref = V.sel(Z=nBotDep, method='nearest')
        vrefbar = vref.groupby('time.month').mean("time")
        vrefprime = vref.groupby('time.month') - vrefbar
    # compute THT_ref and THT_div
    thtref = rho * Cp * (vrefprime * grid.interp(tprime, "Y", boundary="extend"))
    thtdiv = THTdiv.THTvdiv
    # if desired, mask out the regions specified in mld before integration
    if ML:
        thtref = thtref * mld
        thtdiv = thtdiv * mld

    # Zonal integrals (still time-resolved) of each transient family, without a
    # mask and multiplied by each eddy/track mask to get the CME contribution.
    # Dict order fixes the variable order in the output datasets.
    masks = {
        "eddy": em, "eddy_cyclones": em_cycl, "eddy_anticyclones": em_anti,
        "track": tm, "track_cyclones": tm_cycl, "track_anticyclones": tm_anti,
    }
    zints = {}
    for base, field_x in (("THT", tht_x),
                          ("THTref", grid.interp(thtref, "X")),
                          ("THTdiv", grid.interp(thtdiv, "X"))):
        zints[base] = _zonal_integral(field_x)
        for tag, mask in masks.items():
            zints[base + tag] = _zonal_integral(field_x, mask)

    # full-year time mean
    _save({name: zint.mean("time") for name, zint in zints.items()},
          eddypath + 'THTs' + mld_name + '.' + period)
    # repeat for the four individual seasons
    for s in seasons:
        _save({name: _season_mean(zint, s) for name, zint in zints.items()},
              eddypath + 'THTs' + mld_name + '.' + s + '.' + period)
    return

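And finally, we loop over the 10 decades and apply the functions defined above. Each decade is processed three times:
- for the full water column;
- for the region below the mean winter (JJA) mixed layer, shelf excluded (`NoML`);
- for the region within the mixed layer (`ML`).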

In [None]:
# Decade boundaries: start years 201, 211, ..., 291 and end years 210, 220, ..., 300
a = np.arange(201, 300, 10)
b = np.arange(210, 301, 10)
for yy1, yy2 in zip(a, b):
    # Zero-padded year strings. y1/y2 (like rho, Cp, f and savetime below) must
    # remain module-level names because calc_hts reads them as globals.
    y1, y2 = f"{yy1:04}", f"{yy2:04}"
    print(y1, "to", y2)
    # load and prepare the data for the current decade
    data, THTdiv, em, em_cycl, em_anti, tm, tm_cycl, tm_anti = open_data(path, eddypath, y1, y2)
    (data, T, V, shelf, rho, Cp, f0, beta, Cd, f,
     savetime, depMask, nBotDep, grid) = param_and_prepare(data, THTdiv)
    # positional arguments shared by all three calc_hts calls
    common = (data, THTdiv, T, V, shelf, em, em_cycl, em_anti, tm, tm_cycl, tm_anti, nBotDep, grid)
    # 1) full depth, full year heat transports (no mixed-layer masking)
    calc_hts(*common, False, None, None, None, "")
    # mean winter (JJA) mixed-layer depth, as a negative depth interpolated in Y
    mld_jja = grid.interp(-data.MXLDEPTH.groupby("time.season").mean("time").sel(season="JJA"),
                          "Y", boundary="extend")
    # 2) below the mixed layer only; the shelf is excluded through `shelf`
    nomld = shelf.where(data.Z < mld_jja, other=0)
    nomldm = nomlds = nomld
    calc_hts(*common, True, nomld, nomldm, nomlds, "NoML")
    # 3) within the mixed layer only. V / V is 1 where V (first time step) is
    #    non-zero and NaN where it is zero.
    mld = (V / V).isel(time=0).where(data.Z >= mld_jja, other=0)
    mldm = mlds = mld
    calc_hts(*common, True, mld, mldm, mlds, "ML")