### LOG

2/11/21
- Fixed issue with vertical grid (it was because I squeezed the dimensions)
- Now have the capacity to calculate hldot for both tendencies and fluxes, I think this unifies the approach
- The question now is how to wrap it up, e.g. within a class object, and then perform the wmt calculation
***
1/11/21 
- Some success in defining a generic approach to calculating hldot, which should be applicable to 3D and 2D fluxes and tendencies. 
- Wrappers on top of this will allow specification of attributes based on knowledge of diagnostic names
- Further wrapper can be used for calculating density and associated fluxes/tendencies
- Testing datasets and protocols also starting to be established
- Encountering issues relating to the pseudo grid for 2D surface fluxes: for some reason it gets upset at the size of z_l, but I can't see any difference to how I have done this previously



In [1]:
import xarray as xr
from xhistogram.xarray import histogram
import numpy as np
from xgcm import Grid

In [2]:
# Subsidiary functions
def dict_retain_keys(dictionary,retain):
    """Return a new dict holding only the entries of `dictionary` whose key is in `retain`.

    `dictionary` itself is not modified. Kept entries appear in the same
    order in which `dictionary` yields them.
    """
    return {name: value for name, value in dictionary.items() if name in retain}

### Grid functions

In [88]:
# Grid manipulation functions
def expand_surface_to_3D(surfaceflux,z):
    """Put a surface field onto the interfaces `z`, keeping it only at the first one.

    `surfaceflux` gains a 'z_i' dimension taken from `z`; every position whose
    `z` value differs from `z[0]` is set to 0.
    """
    on_interfaces = surfaceflux.expand_dims({'z_i':z})
    is_first_interface = z==z[0]
    return on_interfaces.where(is_first_interface,0)

def expand_grid_in_z(grid):
    """Return a copy of `grid` that is guaranteed to carry 'z_l' and 'z_i'.

    If the copy already has either 'z_l' or 'z_i' among its dimensions it is
    returned as is. Otherwise a "dummy" vertical is attached so that an upper
    cell can be defined: one centre at 2.5 and two interfaces at 0 and 5.
    """
    gridz = grid.copy()
    if 'z_l' in gridz.dims or 'z_i' in gridz.dims:
        # A vertical dimension is already present: nothing to add
        return gridz
    gridz['z_l'] = xr.DataArray(np.array([2.5]),dims='z_l')
    gridz['z_i'] = xr.DataArray(np.array([0.,5.]),dims='z_i')
    return gridz

def get_xgcm_grid(grid):
    """Build an xgcm Grid object from `grid`.

    `grid` must provide 'z_l', 'z_i', 'dxt', 'dyt' and 'areacello'; the metric
    names 'dxCu', 'dxCv', 'dyCu' and 'dyCv' are also handed to xgcm.
    Note that `grid` is modified in place: 'dzt' and 'volcello' are added and
    the missing values of 'dxt', 'dyt', 'dzt' and 'areacello' become zero.
    The X axis is declared periodic.
    """
    # Cell thickness on the centres, taken from the spacing of the interfaces
    interface_spacing = grid['z_i'].diff('z_i')
    grid['dzt'] = grid['z_l'].copy(data=interface_spacing)

    # Fill in nans with zeros
    for name in ('dxt','dyt','dzt','areacello'):
        grid[name] = grid[name].fillna(0.)
    grid['volcello'] = (grid['areacello']*grid['dzt']).fillna(0.)

    axis_positions = {
        'X': {'center': 'xh', 'right': 'xq'},
        'Y': {'center': 'yh', 'right': 'yq'},
        'Z': {'center': 'z_l', 'outer': 'z_i'},
    }
    axis_metrics = {
        ('X',): ['dxt','dxCu','dxCv'], # X distances
        ('Y',): ['dyt','dyCu','dyCv'], # Y distances
        ('Z',): ['dzt'], # Z distances
        ('X', 'Y'): ['areacello'], # Areas
        ('X', 'Y', 'Z'): ['volcello'], # Volumes
    }
    return Grid(grid, coords=axis_positions, metrics=axis_metrics, periodic=['X'])

### Core functions

In [89]:
# Calculation of hldot (cell-depth integral of scalar tendency)
# provided various forms of input (fluxes, tendencies, intensive, extensive)
def hldot_from_Jl(grid,Jl):
    """hldot from an intensive flux: Z-derivative of `Jl` times the Z metric.

    Only the 'Z' axis is treated. The axis could instead be inferred from the
    position of the values, and the same step would be needed for every
    other dimension.
    """
    axis = 'Z'
    ldot = grid.derivative(Jl,axis)
    z_metric = grid.get_metric(ldot,'Z')
    return ldot*z_metric

def hldot_from_ldot(grid,ldot):
    """hldot from an intensive tendency: the Z metric of `ldot` times `ldot`."""
    return grid.get_metric(ldot,'Z')*ldot

def hldot_from_JL(grid,JL):
    """hldot from an extensive flux: the derivative of `JL` along the 'Z' axis."""
    return grid.derivative(JL,'Z')
        
def hldot_from_Ldot(xgrid,Ldot):
    """hldot from an extensive tendency: `Ldot` is returned unchanged.

    `xgrid` is accepted for signature parity with the other hldot_from_*
    functions and is not used.
    """
    return Ldot

def Jl_from_massflux(massflux,scalar_in_mass,scalar_in_surface):
    """Return `massflux` times (`scalar_in_surface` minus `scalar_in_mass`)."""
    scalar_difference = scalar_in_surface-scalar_in_mass
    return massflux*scalar_difference

In [90]:
# These functions could exist as part of a class, of which the xgcm grid could be a part (in which case grid would not need to be called in each function)

def _calc_hldot(xgrid,process,intensive_or_extensive=None,flux_or_tendency=None):
    if (intensive_or_extensive=="intensive") & (flux_or_tendency=="flux"):
        hldot = hldot_from_Jl(xgrid,process)
    elif (intensive_or_extensive=="intensive") & (flux_or_tendency=="tendency"):
        hldot = hldot_from_ldot(xgrid,process)
    elif (intensive_or_extensive=="extensive") & (flux_or_tendency=="flux"):
        hldot = hldot_from_JL(xgrid,process)
    elif (intensive_or_extensive=="extensive") & (flux_or_tendency=="tendency"):
        hldot = hldot_from_Ldot(xgrid,process)
        
    return hldot

def _calc_hldot_massflux(xgrid,process,scalar_in_massflux,scalar_at_boundary):
    """Special case: hldot for a process that is a boundary mass flux.

    The mass flux is first turned into a scalar flux with Jl_from_massflux,
    and that flux is then converted with hldot_from_Jl.
    """
    scalar_flux = Jl_from_massflux(process,scalar_in_massflux,scalar_at_boundary)
    return hldot_from_Jl(xgrid,scalar_flux)
        
def calc_hldot(ds,processname,xgrid):
    """Calculate hldot for variable `processname` of `ds`, steered by its attributes.

    Parameters
    ----------
    ds : container indexed by variable name; `ds[processname]` must expose
        `.attrs` and `.dims`.
    processname : name of the variable to process.
    xgrid : grid object handed to the hldot routines.

    Returns
    -------
    The hldot field, or None (after printing a notice) when the variable has
    none of the WMT-relevant attributes.

    Raises
    ------
    ValueError
        If the variable is flagged as a mass flux but 'scalar_in_massflux' or
        'scalar_at_boundary' is missing or None. Previously this case printed
        a message and then failed on an unbound local variable.
    """
    process = ds[processname]

    ### Checking if it has necessary attributes ###
    # Scrap attributes except those desired
    desired = ["flux_or_tendency","intensive_or_extensive","massflux","scalar_in_massflux","scalar_at_boundary"]
    attrs = dict_retain_keys(process.attrs,desired)
    if not attrs:
        print(processname+" has no WMT-relevant attributes so is not being considered")
        return None

    ### Checking whether it is a 2D surface flux ###
    # Only a field with no vertical dimension at all is a surface flux; it is
    # placed on the very upper interface of the grid, and set to zero elsewhere.
    # Fields already defined on 'z_l' or 'z_i' must be left as they are.
    if ("z_i" not in process.dims) and ("z_l" not in process.dims):
        # NOTE(review): these interface values equal the dummy 'z_i' created by
        # expand_grid_in_z; confirm they match the 'z_i' of the grid behind xgrid.
        z = xr.DataArray(np.array([0.,5.]),dims='z_i')
        process = expand_surface_to_3D(process,z)

    # Special case that the process is associated with a boundary mass flux
    # (commonly freshwater flux at the ocean surface).
    # A 'massflux' attribute that is present but falsy is treated as "not a mass flux".
    if attrs.get("massflux"):
        scalar_in_massflux = attrs.get("scalar_in_massflux")
        scalar_at_boundary = attrs.get("scalar_at_boundary")
        if (scalar_in_massflux is None) or (scalar_at_boundary is None):
            raise ValueError("To evaluate WMT due to boundary mass fluxes requires that the scalar "
                             "concentration in the mass flux and the scalar concentration at the "
                             "exposed boundary be specified")
        # A string names the variable of ds that holds the boundary concentration
        if isinstance(scalar_at_boundary,str):
            scalar_at_boundary = ds[scalar_at_boundary]
        return _calc_hldot_massflux(xgrid,process,scalar_in_massflux,scalar_at_boundary)

    # Pass only the two descriptors that _calc_hldot accepts, so that stray
    # mass-flux attributes cannot trigger an unexpected-keyword TypeError.
    return _calc_hldot(xgrid,process,
                       intensive_or_extensive=attrs.get("intensive_or_extensive"),
                       flux_or_tendency=attrs.get("flux_or_tendency"))

In [91]:
# WMT-relevant attributes to attach to each known diagnostic, keyed by variable name
attrs_dict = {
    # Surface fluxes
    "hfds": dict(
        associated_scalar='tos',
        flux_or_tendency='flux',
        intensive_or_extensive='extensive',
    ),
    "sfdsi": dict(
        associated_scalar='sos',
        flux_or_tendency='flux',
        intensive_or_extensive='extensive',
    ),
    "wfo": dict(
        associated_scalar='sos',
        flux_or_tendency='flux',
        intensive_or_extensive='extensive',
        massflux=True,
        scalar_in_massflux=0,
        scalar_at_boundary='sos',
    ),
    # MOM6 heat tendency variables
    "boundary_forcing_heat_tendency": dict(
        associated_scalar='temp',
        flux_or_tendency='tendency',
        intensive_or_extensive='extensive',
    ),
    "frazil_heat_tendency": dict(
        associated_scalar='temp',
        flux_or_tendency='tendency',
        intensive_or_extensive='extensive',
    ),
    "internal_heat_heat_tendency": dict(
        associated_scalar='temp',
        flux_or_tendency='tendency',
        intensive_or_extensive='extensive',
    ),
    "opottempdiff": dict(
        associated_scalar='temp',
        flux_or_tendency='tendency',
        intensive_or_extensive='extensive',
    ),
}

### 2D fluxes

In [95]:
# Open the monthly diagnostics (every file matching the pattern) and the static grid file
ds = xr.open_mfdataset('../data/raw/testdata_OM4p25/ocean_monthly.201801-201812.*.nc')
grid = xr.open_dataset('../data/raw/testdata_OM4p25/ocean_monthly.static.nc')
# Work on a copy of the grid that has z_l/z_i (dummy ones are added if it has
# neither), then build the xgcm grid from it
gridz = expand_grid_in_z(grid)
xgrid = get_xgcm_grid(gridz)

In [97]:
# Calculate hldot for every variable of ds, tagging the known diagnostics with
# their WMT attributes first
hldot = xr.Dataset()
for var in ds.data_vars:
    if var in attrs_dict:
        ds[var] = ds[var].assign_attrs(attrs_dict[var])
    result = calc_hldot(ds,var,xgrid)
    # calc_hldot returns None for variables it does not consider; keep those
    # out of the output dataset rather than storing None
    if result is not None:
        hldot[var] = result

average_DT has no WMT-relevant attributes so is not being considered
average_T1 has no WMT-relevant attributes so is not being considered
average_T2 has no WMT-relevant attributes so is not being considered


ValueError: indexes along dimension 'z_i' are not equal

In [41]:
# Display the dataset (the notebook output follows)
ds

Unnamed: 0,Array,Chunk
Bytes,96 B,96 B
Shape,"(12,)","(12,)"
Count,20 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 96 B 96 B Shape (12,) (12,) Count 20 Tasks 1 Chunks Type timedelta64[ns] numpy.ndarray",12  1,

Unnamed: 0,Array,Chunk
Bytes,96 B,96 B
Shape,"(12,)","(12,)"
Count,20 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,96 B,96 B
Shape,"(12,)","(12,)"
Count,20 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 96 B 96 B Shape (12,) (12,) Count 20 Tasks 1 Chunks Type datetime64[ns] numpy.ndarray",12  1,

Unnamed: 0,Array,Chunk
Bytes,96 B,96 B
Shape,"(12,)","(12,)"
Count,20 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,96 B,96 B
Shape,"(12,)","(12,)"
Count,20 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray
"Array Chunk Bytes 96 B 96 B Shape (12,) (12,) Count 20 Tasks 1 Chunks Type datetime64[ns] numpy.ndarray",12  1,

Unnamed: 0,Array,Chunk
Bytes,96 B,96 B
Shape,"(12,)","(12,)"
Count,20 Tasks,1 Chunks
Type,datetime64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 71.19 MiB 71.19 MiB Shape (12, 1080, 1440) (12, 1080, 1440) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,192 B,192 B
Shape,"(12, 2)","(12, 2)"
Count,20 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray
"Array Chunk Bytes 192 B 192 B Shape (12, 2) (12, 2) Count 20 Tasks 1 Chunks Type timedelta64[ns] numpy.ndarray",2  12,

Unnamed: 0,Array,Chunk
Bytes,192 B,192 B
Shape,"(12, 2)","(12, 2)"
Count,20 Tasks,1 Chunks
Type,timedelta64[ns],numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 71.19 MiB 71.19 MiB Shape (12, 1080, 1440) (12, 1080, 1440) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 71.19 MiB 71.19 MiB Shape (12, 1080, 1440) (12, 1080, 1440) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 71.19 MiB 71.19 MiB Shape (12, 1080, 1440) (12, 1080, 1440) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 71.19 MiB 71.19 MiB Shape (12, 1080, 1440) (12, 1080, 1440) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,71.19 MiB,71.19 MiB
Shape,"(12, 1080, 1440)","(12, 1080, 1440)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [9]:
# Display the hldot dataset (the notebook output follows)
hldot

Unnamed: 0,Array,Chunk
Bytes,142.38 MiB,142.38 MiB
Shape,"(1, 12, 1080, 1440)","(1, 12, 1080, 1440)"
Count,10 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 142.38 MiB 142.38 MiB Shape (1, 12, 1080, 1440) (1, 12, 1080, 1440) Count 10 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,142.38 MiB,142.38 MiB
Shape,"(1, 12, 1080, 1440)","(1, 12, 1080, 1440)"
Count,10 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,142.38 MiB,142.38 MiB
Shape,"(1, 12, 1080, 1440)","(1, 12, 1080, 1440)"
Count,10 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 142.38 MiB 142.38 MiB Shape (1, 12, 1080, 1440) (1, 12, 1080, 1440) Count 10 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,142.38 MiB,142.38 MiB
Shape,"(1, 12, 1080, 1440)","(1, 12, 1080, 1440)"
Count,10 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,142.38 MiB,142.38 MiB
Shape,"(1, 12, 1080, 1440)","(1, 12, 1080, 1440)"
Count,16 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 142.38 MiB 142.38 MiB Shape (1, 12, 1080, 1440) (1, 12, 1080, 1440) Count 16 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  1440  1080  12,

Unnamed: 0,Array,Chunk
Bytes,142.38 MiB,142.38 MiB
Shape,"(1, 12, 1080, 1440)","(1, 12, 1080, 1440)"
Count,16 Tasks,1 Chunks
Type,float64,numpy.ndarray


### 3D tendencies

In [18]:
# Open the 3D heat diagnostics and the static grid file
ds = xr.open_mfdataset('../data/raw/testdata_Baltic/19000101.ocean_z_heat_1900_01.nc')
grid = xr.open_dataset('../data/raw/testdata_Baltic/19000101.ocean_static.nc')
# get_xgcm_grid takes a single dataset and reads 'z_l' and 'z_i' from it, so the
# vertical coordinates of the diagnostics are attached to the static grid first
# (calling it with two arguments raises a TypeError).
# NOTE(review): assumes ds provides both 'z_l' and 'z_i' - confirm against the file.
grid = grid.assign_coords({'z_l':ds['z_l'],'z_i':ds['z_i']})
xgrid = get_xgcm_grid(grid)

In [19]:
# Calculate hldot for every variable of ds, tagging the known diagnostics with
# their WMT attributes first
hldotds = xr.Dataset()
for var in ds.data_vars:
    if var in attrs_dict:
        ds[var] = ds[var].assign_attrs(attrs_dict[var])
    result = calc_hldot(ds,var,xgrid)
    # calc_hldot returns None for variables it does not consider; keep those
    # out of the output dataset rather than storing None
    if result is not None:
        hldotds[var] = result

temp has no WMT-relevant attributes so is not being considered
opottemptend has no WMT-relevant attributes so is not being considered
opottemppmdiff has no WMT-relevant attributes so is not being considered
T_advection_xy has no WMT-relevant attributes so is not being considered
Th_tendency_vert_remap has no WMT-relevant attributes so is not being considered
average_T1 has no WMT-relevant attributes so is not being considered
average_T2 has no WMT-relevant attributes so is not being considered
average_DT has no WMT-relevant attributes so is not being considered
time_bnds has no WMT-relevant attributes so is not being considered


In [20]:
# Display the hldot dataset of the 3D tendencies (the notebook output follows)
hldotds

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 97.91 MiB 97.91 MiB Shape (36, 1, 35, 105, 97) (36, 1, 35, 105, 97) Count 5 Tasks 1 Chunks Type float64 numpy.ndarray",1  36  97  105  35,

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 97.91 MiB 97.91 MiB Shape (36, 1, 35, 105, 97) (36, 1, 35, 105, 97) Count 5 Tasks 1 Chunks Type float64 numpy.ndarray",1  36  97  105  35,

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 97.91 MiB 97.91 MiB Shape (36, 1, 35, 105, 97) (36, 1, 35, 105, 97) Count 5 Tasks 1 Chunks Type float64 numpy.ndarray",1  36  97  105  35,

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 97.91 MiB 97.91 MiB Shape (36, 1, 35, 105, 97) (36, 1, 35, 105, 97) Count 5 Tasks 1 Chunks Type float64 numpy.ndarray",1  36  97  105  35,

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray


In [14]:
# Pick a single process out of the hldot dataset to work with
hldot = hldotds['boundary_forcing_heat_tendency']

In [15]:
# Display the selected hldot field (the notebook output follows)
hldot

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 97.91 MiB 97.91 MiB Shape (36, 1, 35, 105, 97) (36, 1, 35, 105, 97) Count 5 Tasks 1 Chunks Type float64 numpy.ndarray",1  36  97  105  35,

Unnamed: 0,Array,Chunk
Bytes,97.91 MiB,97.91 MiB
Shape,"(36, 1, 35, 105, 97)","(36, 1, 35, 105, 97)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray


In [16]:
def calc_G(hldot,l,area,bins):
    """Bin the area-integrated hldot by the value of `l`, per unit bin width.

    Parameters
    ----------
    hldot : field to be binned; cells where it is NaN are excluded.
    l : scalar field whose values decide the bin of each cell.
    area : horizontal cell area, multiplied with `hldot` to form the weights.
    bins : bin edges for `l`.

    Returns
    -------
    Sum of `hldot*area` over 'xh', 'yh' and 'z_l' within each bin, divided by
    the width of the bin.

    Previously `hldot` was used only as a mask and `area` not at all, so the
    result was the number of cells per bin width rather than a weighted sum.
    """
    nanmask = ~np.isnan(hldot)
    # Missing weights are set to zero so that a NaN area cannot contaminate a bin
    weights = (hldot*area).fillna(0.)
    return histogram(l.where(nanmask),bins=bins,weights=weights,dim=['xh','yh','z_l'])/np.diff(bins)

In [None]:
# Echo the function object (cell has no recorded output)
calc_G

In [32]:
# Minimal stand-alone check of placing a 2D field on a dummy vertical grid.
# A 10x10 field of ones named 'foo' on dimensions 'x' and 'y'.
da = xr.DataArray(np.ones(shape=(10,10)),dims=['x','y'],coords={'x':np.arange(10),'y':np.arange(10)})
da.name = 'foo'
ds = da.to_dataset()

# Dummy vertical: two interfaces (0 and 5) around a single centre (2.5)
ds['z_i'] = xr.DataArray(np.array([0,5]),dims=['z_i'])
ds['z_l'] = xr.DataArray(np.array([2.5]),dims=['z_l'])
# NOTE(review): the X and Y axes refer to 'xh' and 'yh', but the dataset built
# above only has dimensions 'x' and 'y' - confirm which names are intended.
coords={'X': {'center': 'xh'},
        'Y': {'center': 'yh'},
        'Z': {'center':'z_l','outer':'z_i'}}

xgrid = Grid(ds,coords=coords)

# Give 'foo' a 'z_i' dimension and keep its values only at the first interface
ds['foo'] = ds['foo'].expand_dims({'z_i':ds['z_i']})
ds['foo'] = ds['foo'].where(ds['z_i']==ds['z_i'][0],0)

# Interpolate along the Z axis of the grid
xgrid.interp(ds['foo'],'Z')

In [29]:
# Display the test field
ds['foo']