### Notebook for ensemble-averaged analysis of MPDATA simulations.
This will be generalized to all tracer-advection schemes in another notebook.

In [1]:
#Packages 
import numpy as np
import xgcm
from xgcm import Grid
import xarray as xr
import xroms
from datetime import datetime
import time # for counting 
import glob
from xhistogram.xarray import histogram
import cmocean.cm as cmo
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib.ticker as tick
from matplotlib.dates import DateFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.dates as mdates
from matplotlib.ticker import AutoMinorLocator
from xhistogram.xarray import histogram
from datetime import timedelta
import time

import warnings
warnings.filterwarnings("ignore") #The chaotic option, used to suppress issues with cf_time with xroms 

from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from matplotlib import pyplot as plt, patches

In [2]:
def open_roms(path):
    '''
Open a ROMS netCDF file with xroms and return (dataset, xgcm grid).

NOTE: this name is redefined a few lines below by a version that also trims
ocean_time; that later definition is the one the loading loop actually calls.
    '''
    dataset, xgrid = xroms.roms_dataset(xroms.open_netcdf(path))
    return dataset, xgrid

# Member 0 lives under /d2; members 1-7 share one directory and differ only by
# the ensmb index in the file name.
paths = [
    '/d2/home/dylan/idealized_nummix/shelf_mpdata_uwind_zerop1_dt_30_across2x_50d_avg.nc',
    *(
        '/d1/shared/shelfstrat_wind/tadv_ensembles/'
        f'shelf_mpdata_uwind_zerop1_dt_30_across2x_50d_ensmb{member}_avg.nc'
        for member in range(1, 8)
    ),
]

def open_roms(path, ntimes=721):
    '''
Open a ROMS netCDF file with xroms and keep only the leading time records.

Inputs:
-------
path - path to a ROMS output file
ntimes - number of ocean_time records to keep, i.e. ocean_time[0:ntimes].
         Defaults to 721 (the value previously hard-coded); None keeps all.

Outputs:
--------
ds1 - xarray dataset trimmed along ocean_time
grid1 - xgcm grid from xroms.roms_dataset (built before the time trim)
    '''
    ds1 = xroms.open_netcdf(path)
    ds1, grid1 = xroms.roms_dataset(ds1)
    ds1 = ds1.isel(ocean_time = slice(0, ntimes))
    return ds1, grid1

# Open every ensemble member; ds[k] and grid[k] correspond to paths[k].
ds = []
grid = []
for path in paths:
    ds1, grid1 = open_roms(path)
    ds.append(ds1)
    grid.append(grid1)

In [3]:
def mixing_vint(ds, grid, xislice, etaslice):
    '''
Volume-integrate numerical, physical, and total mixing over a region.

Inputs:
-------
ds - xroms dataset with dye_03, AKr and dV
grid - xgcm grid used to interpolate AKr vertically onto rho points
xislice, etaslice - rho-point slices defining the region (note the order)

Outputs:
--------
xarray Dataset with 'mnum' (dye_03), 'mphy' (AKr on rho points) and
'mtot' (their sum), each multiplied by dV and summed over s, eta and xi.
    '''
    dims = ['s_rho', 'eta_rho', 'xi_rho']

    def integrate(field, label):
        # Weight by cell volume, restrict to the region, then sum.
        result = (field*ds.dV).isel(eta_rho=etaslice, xi_rho=xislice).sum(dims)
        result.attrs = ''  # drop grid attrs so the result can be saved to netCDF
        result.name = label
        return result

    akr_on_rho = grid.interp(ds.AKr, 'Z', boundary='extend')
    numerical = integrate(ds.dye_03, 'mnum')
    physical = integrate(akr_on_rho, 'mphy')
    total = integrate(ds.dye_03 + akr_on_rho, 'mtot')
    return xr.merge([numerical, physical, total])

def mixing_pervol(ds, grid, etaslice, xislice):
    '''
Computes volume-AVERAGED numerical, physical, and total salt mixing for ROMS
model output: each term is volume-integrated over the region and then divided
by the region's total volume.
See Schlichting et al. (2023) JAMES for more information.

NOTE: the argument order here is (etaslice, xislice), the reverse of
mixing_vint; pass the slices by keyword to avoid silently swapping them.

Inputs:
-------
ds - xroms dataset with dye_03 (numerical salt mixing), AKr and dV
grid - xgcm grid used to interpolate AKr vertically onto rho points
etaslice, xislice - rho-point slices defining the region

Outputs:
--------
xarray Dataset with 'mnum', 'mphy' and 'mtot' per unit volume
    '''
    dims = ['eta_rho', 'xi_rho', 's_rho']
    region = dict(eta_rho = etaslice, xi_rho = xislice)

    mnum = ds.dye_03.isel(region)  # Numerical salt mixing
    # Destruction of salt variance; boundary='extend' matches mixing_vint
    mphy = grid.interp(ds.AKr, 'Z', boundary = 'extend').isel(region)
    mtot = mnum + mphy
    dV = ds.dV.isel(region)
    V = dV.sum(dims)

    averaged = []
    for field, label in ((mnum, 'mnum'), (mphy, 'mphy'), (mtot, 'mtot')):
        per_vol = (field*dV).sum(dims)/V
        # Clear attrs on the returned array itself (not an intermediate) so
        # the grid metadata cannot leak into the netCDF output.
        per_vol.attrs = ''
        per_vol.name = label
        averaged.append(per_vol)

    return xr.merge(averaged)

def mixing_vint_top1m(ds, grid, xislice, etaslice):
    '''
Volume-integrate numerical, physical, and total mixing over a region,
counting only cells whose z_rho is above -1 (the "top 1 m" of the name).

Inputs:
-------
ds - xroms dataset with dye_03, AKr, dV and z_rho
grid - xgcm grid used to interpolate AKr vertically onto rho points
xislice, etaslice - rho-point slices defining the region (note the order)

Outputs:
--------
xarray Dataset with 'mnum', 'mphy' and 'mtot'; masked-out cells are NaN
before the sum and are therefore skipped.
    '''
    dims = ['s_rho', 'eta_rho', 'xi_rho']
    near_surface = ds.z_rho > -1

    def integrate(field, label):
        weighted = (field*ds.dV).where(near_surface)
        result = weighted.isel(eta_rho=etaslice, xi_rho=xislice).sum(dims)
        result.attrs = ''  # drop grid attrs so the result can be saved to netCDF
        result.name = label
        return result

    akr_on_rho = grid.interp(ds.AKr, 'Z', boundary='extend')
    return xr.merge([
        integrate(ds.dye_03, 'mnum'),
        integrate(akr_on_rho, 'mphy'),
        integrate(ds.dye_03 + akr_on_rho, 'mtot'),
    ])

def energy_vint(ds,grid,etaslice,xislice):
    '''
Computes volume-integrated eddy, mean, and total kinetic energy, modified from
Hetland (2017) JPO.

Inputs:
-------
ds - xroms dataset with u, v and dV
grid - xgcm grid used by xroms.to_rho to move u and v onto rho points
etaslice, xislice - rho-point slices defining the region (note the order)

Outputs:
--------
xarray Dataset with 'eke', 'mke' and 'tke', each multiplied by dV and
summed over eta_rho, xi_rho and s_rho

Notes:
------
EKE = 1/2(uprime^2 + vprime^2)
MKE = 1/2(ubar^2 + vbar^2)
TKE = 1/2(u^2 + v^2)
u = ubar+uprime, ubar = 1/L int_0^L u dx, i.e., alongshore (xi) mean
v = vbar+vprime
Velocities interpolated to their respective rho points
    '''
    dims = ['eta_rho', 'xi_rho', 's_rho']

    u = xroms.to_rho(ds.u, grid)
    urho = u.isel(eta_rho = etaslice, xi_rho = xislice)
    v = xroms.to_rho(ds.v, grid)
    vrho = v.isel(eta_rho = etaslice, xi_rho = xislice)

    # Alongshore mean and the deviation from it
    ubar = urho.mean('xi_rho')
    uprime = urho-ubar
    vbar = vrho.mean('xi_rho')
    vprime = vrho-vbar

    dV = ds.dV.isel(eta_rho = etaslice, xi_rho = xislice)

    # Mean kinetic energy; ubar/vbar lack xi_rho and broadcast against dV
    mke = 0.5*(ubar**2+vbar**2)
    mke_int = (mke*dV).sum(dims)
    mke_int.attrs = ''
    mke_int.name = 'mke'

    # Eddy kinetic energy
    eke = 0.5*(uprime**2 + vprime**2)
    eke_int = (eke*dV).sum(dims)
    eke_int.attrs = ''
    eke_int.name = 'eke'

    # Total kinetic energy
    tke = 0.5*(urho**2+vrho**2)
    tke_int = (tke*dV).sum(dims)
    tke_int.attrs = ''
    tke_int.name = 'tke'

    return xr.merge([eke_int, mke_int, tke_int])

def sprime2_whole(ds, grid, xislice, etaslice):
    '''
Return the volume-averaged salinity variance over the selected region.

NOTE: this name is redefined later in this cell by a version that also returns
the vertical/horizontal decomposition; once the cell has run, that later
definition is the one in effect.

Inputs:
-------
ds - xroms dataset with salt and dV
grid - not used; kept so all sprime2_* functions share a signature
xislice, etaslice - rho-point slices defining the region

Outputs:
--------
DataArray <(s - <s>)^2>, where <.> is the dV-weighted mean over s_rho,
xi_rho and eta_rho
    '''
    region = dict(eta_rho=etaslice, xi_rho=xislice)
    dims = ['s_rho', 'xi_rho', 'eta_rho']
    s = ds.salt.isel(region)
    weights = ds.dV.isel(region)
    volume = weights.sum(dims)

    def vol_mean(q):
        return (1/volume)*((q*weights).sum(dims))

    s_mean = vol_mean(s)
    return vol_mean((s - s_mean)**2)


def sprime2_top1m(ds, grid, xislice, etaslice):
    '''
Return the volume-averaged salinity variance over the selected region,
using only cells whose z_rho is above -1 (the "top 1 m" of the name).

Inputs:
-------
ds - xroms dataset with salt, dV and z_rho
grid - not used; kept so all sprime2_* functions share a signature
xislice, etaslice - rho-point slices defining the region

Outputs:
--------
DataArray <(s - <s>)^2>, where <.> is the dV-weighted mean over the
unmasked cells (masked cells are NaN and skipped by the sums)
    '''
    keep = ds.z_rho > -1
    region = dict(eta_rho=etaslice, xi_rho=xislice)
    dims = ['s_rho', 'xi_rho', 'eta_rho']
    s = ds.salt.where(keep).isel(region)
    weights = ds.dV.where(keep).isel(region)
    volume = weights.sum(dims)

    def vol_mean(q):
        return (1/volume)*((q*weights).sum(dims))

    s_mean = vol_mean(s)
    return vol_mean((s - s_mean)**2)

def sprime2_whole(ds, grid, xislice, etaslice):
    '''
Return the volume-averaged salinity variance and its split into vertical and
horizontal parts. See Li et al. (2018) JPO for details.

Inputs:
-------
ds - xroms dataset with salt, dV and dz
grid - not used; kept so all sprime2_* functions share a signature
xislice, etaslice - rho-point slices defining the region

Outputs:
--------
xarray Dataset with
  svar_tot  - <(s - <s>)^2>, <.> = dV-weighted mean over the region
  svar_vert - <(s - s_z)^2>, s_z = dz-weighted mean of s in each water column
  svar_horz - svar_tot - svar_vert
    '''
    region = dict(eta_rho=etaslice, xi_rho=xislice)
    dims = ['s_rho', 'xi_rho', 'eta_rho']
    s = ds.salt.isel(region)
    weights = ds.dV.isel(region)
    thickness = ds.dz.isel(region)
    volume = weights.sum(dims)

    def vol_mean(q):
        # dV-weighted mean over the whole selected region
        return (1/volume)*((q*weights).sum(dims))

    # Depth-mean salinity of each water column
    column_mean = (1/thickness.sum(['s_rho']))*((s*thickness).sum(['s_rho']))

    total = vol_mean((s - vol_mean(s))**2)
    vertical = vol_mean((s - column_mean)**2)
    horizontal = total - vertical

    for arr, label in ((total, 'svar_tot'), (vertical, 'svar_vert'), (horizontal, 'svar_horz')):
        arr.attrs = ''
        arr.name = label

    return xr.merge([total, vertical, horizontal])

def rho_linear_eos(ds, *, rho0=1027, s0=35, t0=25, scoef=7.6e-4, tcoef=1.7e-4):
    '''
Calculate density from the linear equation of state described in Hetland (2017):

    rho = rho0*(1 + scoef*(salt - s0) - tcoef*(temp - t0))

Inputs:
-------
ds - object exposing `salt` and `temp` (e.g. an xarray Dataset); any array
     type supporting arithmetic works
rho0, s0, t0, scoef, tcoef - keyword-only EOS constants; the defaults are the
     values previously hard-coded here (1027, 35, 25, 7.6e-4, 1.7e-4)

Outputs:
--------
rho - density, broadcast from ds.salt and ds.temp
    '''
    return rho0*((1 + scoef*(ds.salt - s0)) - tcoef*(ds.temp - t0))

def calc_ape(ds, etaslice, xislice):
    '''
Calculate the volume-integrated APE stored in lateral density gradients.
See Eqs. B6 and B7 of Hetland (2017). Only the density-gradient term is
computed here; no SSH contribution is included.

Inputs: 
-------
ds - xarray dataset with salt, temp, z_rho and dV
etaslice - across-shore slice (i.e. slice(1,100)) for init. stratified region. 
xislice - alongshore slice (i.e., slice(1,-1)) to remove periodic BCs

Outputs:
--------
ape_r: xarray DataArray named 'ape_r' with energy stored in lateral density
       gradients, summed over s_rho, eta_rho and xi_rho
    '''
    rho0 = 1025 #background density determined from input file
    g = 9.81 

    rho = rho_linear_eos(ds)
    b = (g*(rho0-rho))*(1/rho0) 

    # Reference state: temperature-only part of the EOS at the first record.
    # .values strips coords so it broadcasts against b over the trailing
    # (s_rho, eta_rho, xi_rho) axes -- assumes time is the leading dimension.
    rho_temp = (1027*(1-(1.7*10**-4*(ds.temp[0]-25)))).values
    bref = (g*(rho0-(rho_temp)))*(1/rho0) #Reference buoyancy @ first time. Function of x,y,z

    bp = b-bref

    #Lateral density gradient APE
    ape_r = -((rho0*bp*ds.z_rho*ds.dV).isel(eta_rho = etaslice, xi_rho = xislice).sum(['s_rho', 'eta_rho', 'xi_rho']))
    ape_r.attrs = ''
    ape_r.name = 'ape_r'

    return ape_r

In [4]:
# Run the functions
# -----------------
# Argument order differs between functions: mixing_vint, mixing_vint_top1m and
# the sprime2_* functions take (xislice, etaslice), while mixing_pervol,
# energy_vint and calc_ape take (etaslice, xislice). Slices are passed by
# keyword below so they cannot be silently swapped.
xislice = slice(1,-1)
etaslice = slice(1,193) 
# etaslice = slice(1,100)

# dsm = []
# for ds_i, grid_i in zip(ds, grid):
#     ds_mix = mixing_vint(ds_i, grid_i, xislice=xislice, etaslice=etaslice)
#     dsm.append(ds_mix)

# dsm = []
# for ds_i, grid_i in zip(ds, grid):
#     ds_mix = mixing_pervol(ds_i, grid_i, etaslice=etaslice, xislice=xislice)
#     dsm.append(ds_mix)

# dsm_1m = []
# for ds_i, grid_i in zip(ds, grid):
#     ds_mix_1m = mixing_vint_top1m(ds_i, grid_i, xislice=xislice, etaslice=etaslice)
#     dsm_1m.append(ds_mix_1m)

dse = []
for ds_i, grid_i in zip(ds, grid):
    ds_energy = energy_vint(ds_i, grid_i, etaslice=etaslice, xislice=xislice)
    dse.append(ds_energy)

# svar = []
# for ds_i, grid_i in zip(ds, grid):
#     svar_da = sprime2_whole(ds_i, grid_i, xislice=xislice, etaslice=etaslice)
#     svar.append(svar_da)

# svar_1m = []
# for ds_i, grid_i in zip(ds, grid):
#     sp_1m = sprime2_top1m(ds_i, grid_i, xislice=xislice, etaslice=etaslice)
#     svar_1m.append(sp_1m)

# ds_ape = []
# for ds_i in ds:
#     start_time = time.time()   
#     ape = calc_ape(ds_i, etaslice=etaslice, xislice=xislice)
#     ds_ape.append(ape)
#     print(time.time() - start_time)

In [5]:
from pathlib import Path

n = 0 # Start at zero
# n = 5

# for i in range(len(paths)):
#     # svar[i].attrs = ''
#     # svar[i].name = 'svar'
#     p = 'svar_vavg_whole_mpdata_'+'ensemble_'+str(i+n)+'.nc'
#     svar[i].to_netcdf(p, mode = 'w')
    
# for i in range(len(paths)):
#     svar_1m[i].attrs = ''
#     svar_1m[i].name = 'svar'
#     p = 'svar_vavg_top1m_mpdata_'+'ensemble_'+str(i+n)+'.nc'
#     svar_1m[i].to_netcdf(p, mode = 'w')

# to_netcdf does not create missing parent directories, so make sure
# 'outputs/' exists before writing into it.
Path('outputs').mkdir(exist_ok=True)
for i, ds_energy_i in enumerate(dse):
    # Output index is offset by n so later batches of members can be numbered on.
    p = f'outputs/ene_vint_whole_mpdata_ensemble_{i+n}.nc'
    ds_energy_i.to_netcdf(p, mode = 'w')
    
# for i in range(len(paths)):
#     p = 'mix_vint_whole_mpdata_'+'ensemble_'+str(i+n)+'.nc'
#     dsm[i].to_netcdf(p, mode = 'w')

# for i in range(len(paths)):
#     p = 'mix_pervol_whole_mpdata_'+'ensemble_'+str(i+n)+'.nc'
#     dsm[i].to_netcdf(p, mode = 'w')
    
    
# for i in range(len(paths)):
#     p = 'mix_vint_top1m_mpdata_'+'ensemble_'+str(i+n)+'.nc'
#     dsm_1m[i].to_netcdf(p, mode = 'w')

# for i in range(len(paths)):
#     start_time = time.time()   
#     p = 'outputs/ape_vint_whole_mpdata_'+'ensemble_'+str(i+n)+'.nc'
#     ds_ape[i].to_netcdf(p, mode = 'w')
#     print(time.time() - start_time)