In [1]:
import numpy as np
import xarray
import xesmf

# Manually adding this to the path to 
# avoid problems installing on analysis. 
# https://github.com/raphaeldussin/HCtFlood
# code courtesy of Raphael Dussin - https://github.com/raphaeldussin/HCtFlood/blob/master/HCtFlood/kara.py

from numba import njit
import numpy as np
from dask.base import tokenize
import dask.array as dsa
import xarray as xr


def flood_kara(data, xdim='lon', ydim='lat', zdim='z', tdim='time',
               spval=1e+15):
    """Apply extrapolation onto land from Kara algo.

    The input is handled as a 4-D (tdim, zdim, ydim, xdim) array; every
    (time, level) slice is flooded independently by ``flood_kara_xr`` inside
    its own lazily evaluated dask task.

    Arguments:
        data {xarray.DataArray} -- input data; NaN marks land points
    Keyword Arguments:
        xdim {str} -- name of x dimension (default: {'lon'})
        ydim {str} -- name of y dimension (default: {'lat'})
        zdim {str} -- name of z dimension (default: {'z'})
        tdim {str} -- name of time dimension (default: {'time'})
        spval {float} -- missing value (default: {1e+15})
    Returns:
        xarray.DataArray -- result of the extrapolation with dims
        (tdim, zdim, ydim, xdim) and dtype of ``data``, one chunk per time record
    """
    # Give 2-D / 3-D inputs singleton time and depth dimensions.
    if tdim not in data.dims:
        data = data.expand_dims(dim=tdim)
    if zdim not in data.dims:
        data = data.expand_dims(dim=zdim)
    # expand_dims prepends new dims and callers may use any dim order; force the
    # (t, z, y, x) layout that the chunk shapes below declare, otherwise an
    # (x, y)-ordered slice would silently land in a (ny, nx) chunk.
    data = data.transpose(tdim, zdim, ydim, xdim)

    nrec = len(data[tdim])
    nlev = len(data[zdim])
    ny = len(data[ydim])
    nx = len(data[xdim])
    shape = (nrec, nlev, ny, nx)
    chunks = (1, 1, ny, nx)

    def compute_chunk(zlev, trec):
        # Flood one horizontal (ydim, xdim) slice, cast it to the dtype declared
        # for the dask array (the flooding itself works in float64) and restore
        # the two leading singleton axes of the chunk.
        data_slice = data.isel({tdim: trec, zdim: zlev})
        flooded = flood_kara_xr(data_slice, spval=spval)
        return np.asarray(flooded, dtype=data.dtype)[None, None]

    # The graph name must identify the *contents*: keying on name and shape
    # alone gives two different inputs (e.g. the same variable from two files)
    # identical dask keys, so dask could reuse one result for the other.
    name = str(data.name) + '-' + tokenize(data.name, shape, data, spval)
    dsk = {(name, rec, lev, 0, 0): (compute_chunk, lev, rec)
           for lev in range(nlev)
           for rec in range(nrec)}

    out = dsa.Array(dsk, name, chunks,
                    dtype=data.dtype, shape=shape)

    xout = xr.DataArray(data=out, name=str(data.name),
                        coords={tdim: data[tdim],
                                zdim: data[zdim],
                                ydim: data[ydim],
                                xdim: data[xdim]},
                        dims=(tdim, zdim, ydim, xdim))

    # rechunk the result: one chunk per time record
    xout = xout.chunk({tdim: 1, zdim: nlev, ydim: ny, xdim: nx})

    return xout

def flood_kara_xr(dataarray, spval=1e+15):
    """Flood the land points of a 2-D xarray.DataArray with the Kara method.

    Arguments:
        dataarray {xarray.DataArray} -- input 2d data array; any length-1
            dimensions are squeezed away before flooding
    Keyword Arguments:
        spval {float} -- missing value (default: {1e+15})
    Returns:
        numpy.ndarray -- field after extrapolation
    """
    squeezed = dataarray.squeeze()
    # to_masked_array masks the null entries; flood_kara_ma treats masked
    # points as land.
    return flood_kara_ma(squeezed.to_masked_array(), spval=spval)

def flood_kara_ma(masked_array, spval=1e+15):
    """Apply flood_kara on a numpy masked array.

    Points that are masked or NaN are land; they are filled from the valid
    (sea) points by ``flood_kara_raw``.

    Arguments:
        masked_array {np.ma.masked_array} -- 2-D array to extrapolate (a plain
            ndarray is accepted too; then only its NaNs are land)
    Keyword Arguments:
        spval {float} -- placeholder written into land points before the
            extrapolation; land values always get zero weight, this only keeps
            NaN out of the weighted sums (default: {1e+15})
    Returns:
        out -- field after extrapolation. When there is no valid point at all,
        a copy of the input values is returned unchanged.
    """
    # Work on a copy so the caller's array is never modified in place.
    field = np.array(np.ma.getdata(masked_array))
    # A point is land if it is masked, or NaN even though unmasked (an
    # unmasked NaN would otherwise count as sea and spread spval around).
    land = np.ma.getmaskarray(masked_array) | np.isnan(field)

    if land.all():
        # No valid value anywhere: nothing to extrapolate from.
        return field

    # 0 * NaN is NaN, so land values must be finite even with zero weight.
    field[land] = spval
    mask = np.where(land, 0.0, 1.0)
    return flood_kara_raw(field, mask)


def flood_kara_raw(field, mask, nmax=1000):
    """Extrapolate sea values onto land using the kara method
    (https://doi.org/10.1175/JPO2984.1)

    Each sweep fills every land point whose sea neighbours carry a total
    weight of at least 3 (orthogonal neighbours weigh 2, diagonal ones 1)
    with the weighted mean of those neighbours. A sweep reads only the state
    left by the previous sweep, so it is evaluated for the whole grid at once
    with array slicing instead of an interpreted loop over every point.

    Arguments:
        field {np.ndarray} -- 2-D field to extrapolate; values at land points
            are ignored but must be finite (0 * NaN/inf poisons the sums)
        mask {np.ndarray} -- land/sea binary mask (0/1), same shape as field
    Keyword Arguments:
        nmax {int} -- max number of iteration (default: {1000})
    Returns:
        drowned -- float64 field after extrapolation, same shape as field
    Raises:
        ValueError -- if land remains after nmax sweeps, or if a sweep fills
            nothing (the remaining land can then never be reached)
    """
    ny, nx = field.shape
    nxy = nx * ny
    # Pad with a one-cell halo of zero mask so every interior point has eight
    # neighbours; halo cells have zero weight and never contribute.
    ztmp = np.zeros((ny + 2, nx + 2))
    zmask = np.zeros((ny + 2, nx + 2))
    ztmp[1:-1, 1:-1] = field
    zmask[1:-1, 1:-1] = mask
    # Interior views; writes through them update the padded arrays.
    ztmp_in = ztmp[1:-1, 1:-1]
    zmask_in = zmask[1:-1, 1:-1]

    # (row offset, column offset, weight), in the summation order of the
    # reference point-by-point formulation.
    stencil = ((-1, -1, 1), (-1, 0, 2), (-1, 1, 1),
               (0, -1, 2), (0, 1, 2),
               (1, -1, 1), (1, 0, 2), (1, 1, 1))

    nt = 0
    while zmask_in.sum() < nxy and nt < nmax:
        weighted_sum = None
        ctot = None
        for dj, di, weight in stencil:
            rows = slice(1 + dj, ny + 1 + dj)
            cols = slice(1 + di, nx + 1 + di)
            c = weight * zmask[rows, cols]
            term = c * ztmp[rows, cols]
            weighted_sum = term if weighted_sum is None else weighted_sum + term
            ctot = c if ctot is None else ctot + c

        # Land points with enough sea around them become sea in this sweep.
        fill = (zmask_in == 0) & (ctot >= 3)
        if not fill.any():
            # Nothing changed, so no later sweep can change anything either.
            raise ValueError('extrapolation stalled: remaining land points '
                             'have no sufficiently weighted sea neighbours')
        ztmp_in[fill] = weighted_sum[fill] / ctot[fill]
        zmask_in[fill] = 1
        nt += 1

    # Only fail if land is actually left; converging on sweep nmax is fine.
    if zmask_in.sum() < nxy:
        raise ValueError('number of iterations exceeded maximum, '
                         'try increasing nmax')

    drowned = ztmp_in

    return drowned

def vgrid_to_interfaces(vgrid, max_depth=6500.0):
    """Turn a column of layer thicknesses into interface depths.

    Args:
        vgrid: layer thicknesses (array-like or xarray.DataArray).
        max_depth: depth assigned to the lowest interface, replacing the
            cumulative thickness there.
    Returns:
        numpy array of len(vgrid) + 1 interface depths, starting at 0.
    """
    thickness = vgrid.data if isinstance(vgrid, xarray.DataArray) else vgrid
    interfaces = np.concatenate([[0], np.cumsum(thickness)])
    interfaces[-1] = max_depth
    return interfaces


def vgrid_to_layers(vgrid, max_depth=6500.0):
    """Turn a column of layer thicknesses into layer-midpoint depths.

    Args:
        vgrid: layer thicknesses (array-like or xarray.DataArray).
        max_depth: depth assigned to the lowest interface before midpoints
            are taken.
    Returns:
        numpy array with one midpoint depth per layer.
    """
    if isinstance(vgrid, xarray.DataArray):
        vgrid = vgrid.data
    interfaces = vgrid_to_interfaces(vgrid, max_depth=max_depth)
    # Each midpoint is the mean of a layer's bottom and top interface.
    return (interfaces[1:] + interfaces[:-1]) / 2


def rotate_uv(u, v, angle, in_degrees=False):
    """Rotate velocities from earth-relative to model-relative axes.

    Applies, elementwise,
        urot =  cos(angle) * u + sin(angle) * v
        vrot = -sin(angle) * u + cos(angle) * v

    Args:
        u: west-east component of velocity.
        v: south-north component of velocity.
        angle: angle of rotation from true north to model north.
        in_degrees (bool): angle is taken as radians unless this is True.
    Returns:
        Tuple (urot, vrot) of model-relative west-east and south-north components.
    """
    radians = np.radians(angle) if in_degrees else angle
    cos_a = np.cos(radians)
    sin_a = np.sin(radians)
    return cos_a * u + sin_a * v, -sin_a * u + cos_a * v


def interpolate_flood_tracers(ds, target_grid):
    """Interpolate and flood data at tracer points (temperature, salinity, free surface).

    Land points are flooded with the Kara method on the source grid first, and
    the flooded fields are then regridded bilinearly.

    Args:
        ds (xarray.Dataset): Dataset with variables temp, salt, and ssh on
            lon/lat (temp and salt also on zl).
        target_grid (xarray.Dataset): Model supergrid with variables x, y and coords nxp, nyp.
    Returns:
        xarray.Dataset: Dataset flooded and interpolated to MOM tracer grid
        (dims xh, yh; the source lon/lat coordinates are dropped).
    """
    # Flood temperature and salinity over land.
    flooded = xarray.merge(
        [flood_kara(ds[v], zdim='zl') for v in ['temp', 'salt']]
    )

    # Flood ssh separately: flood_kara gives 2-D input a singleton 'z' dim,
    # which is removed again so ssh does not gain an extra depth axis.
    flooded['ssh'] = flood_kara(ds['ssh']).isel(z=0).drop_vars('z')

    # Tracer (h) points are the odd-indexed supergrid points.
    target_points = (
        target_grid
        [['x', 'y']]
        .isel(nxp=slice(1, None, 2), nyp=slice(1, None, 2))
        .rename({'y': 'lat', 'x': 'lon', 'nxp': 'xh', 'nyp': 'yh'})
    )
    soda_to_mom = xesmf.Regridder(
        flooded,
        target_points,
        method='bilinear',
        filename='regrid_soda_tracers.nc',
        reuse_weights=False,
        periodic=True
    )
    # drop_vars replaces the deprecated Dataset.drop for removing variables.
    interped = soda_to_mom(flooded).drop_vars(['lon', 'lat'])
    return interped


def interpolate_flood_velocity(ds, target_grid):
    """Interpolate and flood velocity data.

    Velocities are flooded over land, regridded onto the full supergrid,
    rotated to model-relative components using the supergrid's angle_dx, and
    finally subset onto the staggered u and v points.

    Args:
        ds (xarray.Dataset): Dataset with variables u and v.
        target_grid (xarray.Dataset): Model supergrid with variables x, y,
            angle_dx and coords nxp, nyp.
    Returns:
        xarray.Dataset: Dataset flooded and interpolated to MOM velocity grid.
    """
    # Flood over land.
    flooded = xarray.merge(
        [flood_kara(ds[v], zdim='zl') for v in ['u', 'v']]
    )

    # Interpolate u and v onto the full supergrid so both components live on
    # the points where angle_dx is defined, making rotation possible.
    target_uv = (
        target_grid
        [['x', 'y']]
        .rename({'y': 'lat', 'x': 'lon'})
    )
    soda_to_uv = xesmf.Regridder(
        ds, target_uv,
        filename='regrid_soda_uv.nc',
        method='nearest_s2d',
        reuse_weights=False,
        periodic=True
    )
    interped_uv = soda_to_uv(flooded[['u', 'v']]).drop_vars(['lon', 'lat'])

    # Rotate to model-relative components. MOM6/FRE supergrids
    # (ocean_hgrid.nc) store angle_dx in degrees, but rotate_uv assumes
    # radians unless told otherwise; passing degrees as radians rotates by a
    # wrong angle wherever the grid is not aligned with east. Honour an
    # explicit units attribute, defaulting to the supergrid convention.
    angle = target_grid['angle_dx']
    angle_units = str(angle.attrs.get('units', 'degrees')).lower()
    urot, vrot = rotate_uv(interped_uv['u'], interped_uv['v'], angle,
                           in_degrees=angle_units.startswith('deg'))

    # Subset onto u and v points.
    uo = urot.isel(nxp=slice(0, None, 2), nyp=slice(1, None, 2)).rename({'nxp': 'xq', 'nyp': 'yh'})
    uo.name = 'u'
    vo = vrot.isel(nxp=slice(1, None, 2), nyp=slice(0, None, 2)).rename({'nxp': 'xh', 'nyp': 'yq'})
    vo.name = 'v'

    interped = (
        xarray.merge((uo, vo))
        .transpose('time', 'zl', 'yh', 'yq', 'xh', 'xq')
    )

    return interped


def write_initial(soda_file, vgrid_file, grid_file, start_date, output_file):
    """Interpolate initial conditions for MOM from a SODA file and write to a new file.

    Steps: vertical interpolation of SODA onto the model layer midpoints,
    flooding of land points plus horizontal regridding to the MOM tracer and
    velocity points, and writing a netCDF3 file without _FillValue
    attributes. All input files are closed before returning.

    Args:
        soda_file (str): Path to SODA file to use for initial conditions.
        vgrid_file (str): Path to vertical grid (layer thicknesses) to interpolate data to.
        grid_file (str): Path to horizontal grid file (ocean_hgrid.nc) to interpolate data to.
        start_date (np.datetime64): Overwrite the SODA datetime with this datetime. Useful if model start date and SODA 5-day dates do not match.
        output_file (str): Write resulting initial conditions to this file.
    """
    # Target depths: midpoints of the model layers. vgrid_to_layers turns the
    # thicknesses into a numpy array, so the file can be closed right away.
    with xarray.open_dataarray(vgrid_file) as vgrid:
        z = vgrid_to_layers(vgrid)
    ztarget = xarray.DataArray(
        z,
        name='zl',
        dims=['zl'],
        coords={'zl': z}
    )

    # xarray reads these files lazily, so keep them open until the output has
    # been written, then close them instead of leaking the handles.
    with xarray.open_dataset(soda_file) as soda_ds, \
            xarray.open_dataset(grid_file) as grid:
        soda = (
            soda_ds
            .rename({'st_ocean': 'z'})
            [['temp', 'salt', 'ssh', 'u', 'v']]
        )

        # Interpolate SODA vertically onto the model layers. Depths beyond
        # SODA's deepest level are linearly extrapolated; NaNs that remain are
        # filled downward from the last valid value above (ffill along zl).
        revert = soda.interp(z=ztarget, kwargs={'fill_value': 'extrapolate'}).ffill('zl', limit=None)

        # Split SODA into data on tracer and velocity points
        tracers = revert[['temp', 'salt', 'ssh']].rename({'xt_ocean': 'lon', 'yt_ocean': 'lat'})
        velocity = revert[['u', 'v']].rename({'xu_ocean': 'lon', 'yu_ocean': 'lat'})

        # Flood land and horizontally interpolate the vertically interpolated
        # data onto the MOM grid.
        interped = xarray.merge((
            interpolate_flood_tracers(tracers, grid),
            interpolate_flood_velocity(velocity, grid)
        ))

        # Overwrite the SODA file time with the intended model start date.
        interped['time'] = (('time', ), [start_date])

        # Fix output metadata, including removing all _FillValues.
        all_vars = list(interped.data_vars.keys()) + list(interped.coords.keys())
        encodings = {v: {'_FillValue': None} for v in all_vars}
        encodings['time'].update({'dtype': 'float64', 'calendar': 'gregorian'})
        interped['zl'].attrs = {
            'units': 'meter',
            'cartesian_axis': 'Z',
            'positive': 'down'
        }

        interped.to_netcdf(
            output_file,
            format='NETCDF3_64BIT',
            engine='netcdf4',
            encoding=encodings,
            unlimited_dims='time'
        )



In [2]:
# NOTE(review): this cell repeats the body of write_initial() step by step so
# that intermediates (interped, encodings, output_file) stay available to the
# inspection / writing cells below; keep the two in sync when editing.
soda_file = '/Users/james/Downloads/soda3.12.2_5dy_ocean_reg_1993_01_04.nc'
start_date = np.datetime64('1993-01-04T00:00:00')
# Used in filename below, don't change
start_str = np.datetime_as_string(start_date, unit='D')

# Save the ICs here:
output_file = f'/Users/james/Documents/Github/esm_lab/obc_ic/nwa_shared/soda_ic_75z_{start_str}.nc'

# Model vertical grid:
vgrid_file = '/Users/james/Documents/Github/esm_lab/obc_ic/vgrid_75_2m.nc'

# Model horizontal grid:
grid_file = '/Users/james/Downloads/nwa_ocean_hgrid.nc'


# Target depths: midpoints of the model layers, built from the layer thicknesses.
vgrid = xarray.open_dataarray(vgrid_file)
z = vgrid_to_layers(vgrid)
ztarget = xarray.DataArray(
    z,
    name='zl',
    dims=['zl'], 
    coords={'zl': z}
)

soda = (
    xarray.open_dataset(soda_file)
    .rename({'st_ocean': 'z'})
    [['temp', 'salt', 'ssh', 'u', 'v']]
)

# Interpolate SODA vertically onto the model layers. Depths beyond SODA's
# deepest level are linearly extrapolated (fill_value='extrapolate'); NaNs that
# remain are filled downward from the last valid value above (ffill along zl).
revert = soda.interp(z=ztarget, kwargs={'fill_value': 'extrapolate'}).ffill('zl', limit=None)

# Split SODA into data on tracer and velocity points
tracers = revert[['temp', 'salt', 'ssh']].rename({'xt_ocean': 'lon', 'yt_ocean': 'lat'})
velocity = revert[['u', 'v']].rename({'xu_ocean': 'lon', 'yu_ocean': 'lat'})

# Flood land and horizontally interpolate the vertically interpolated data
# onto the MOM grid (flooding happens inside interpolate_flood_*).
grid = xarray.open_dataset(grid_file)

interped = xarray.merge((
    interpolate_flood_tracers(tracers, grid),
    interpolate_flood_velocity(velocity, grid)
))

# Overwrite the SODA file time with the intended model start date.
interped['time'] = (('time', ), [start_date])

# Fix output metadata, including removing all _FillValues.
all_vars = list(interped.data_vars.keys()) + list(interped.coords.keys())
encodings = {v: {'_FillValue': None} for v in all_vars}
encodings['time'].update({'dtype':'float64', 'calendar': 'gregorian'})
interped['zl'].attrs = {
    'units': 'meter',
    'cartesian_axis': 'Z',
    'positive': 'down'
}



  ds_out = xr.apply_ufunc(
  ds_out = xr.apply_ufunc(


In [3]:
interped  # rich repr of the merged (lazy, dask-backed) initial-condition dataset

Unnamed: 0,Array,Chunk
Bytes,148.32 MiB,148.32 MiB
Shape,"(1, 75, 360, 720)","(1, 75, 360, 720)"
Count,79 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 148.32 MiB 148.32 MiB Shape (1, 75, 360, 720) (1, 75, 360, 720) Count 79 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  720  360  75,

Unnamed: 0,Array,Chunk
Bytes,148.32 MiB,148.32 MiB
Shape,"(1, 75, 360, 720)","(1, 75, 360, 720)"
Count,79 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.32 MiB,148.32 MiB
Shape,"(1, 75, 360, 720)","(1, 75, 360, 720)"
Count,79 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 148.32 MiB 148.32 MiB Shape (1, 75, 360, 720) (1, 75, 360, 720) Count 79 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  720  360  75,

Unnamed: 0,Array,Chunk
Bytes,148.32 MiB,148.32 MiB
Shape,"(1, 75, 360, 720)","(1, 75, 360, 720)"
Count,79 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.98 MiB,1.98 MiB
Shape,"(1, 360, 720)","(1, 360, 720)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 1.98 MiB 1.98 MiB Shape (1, 360, 720) (1, 360, 720) Count 5 Tasks 1 Chunks Type float64 numpy.ndarray",720  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.98 MiB,1.98 MiB
Shape,"(1, 360, 720)","(1, 360, 720)"
Count,5 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.52 MiB,148.52 MiB
Shape,"(1, 75, 360, 721)","(1, 75, 360, 721)"
Count,167 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 148.52 MiB 148.52 MiB Shape (1, 75, 360, 721) (1, 75, 360, 721) Count 167 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  721  360  75,

Unnamed: 0,Array,Chunk
Bytes,148.52 MiB,148.52 MiB
Shape,"(1, 75, 360, 721)","(1, 75, 360, 721)"
Count,167 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.73 MiB,148.73 MiB
Shape,"(1, 75, 361, 720)","(1, 75, 361, 720)"
Count,167 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 148.73 MiB 148.73 MiB Shape (1, 75, 361, 720) (1, 75, 361, 720) Count 167 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  720  361  75,

Unnamed: 0,Array,Chunk
Bytes,148.73 MiB,148.73 MiB
Shape,"(1, 75, 361, 720)","(1, 75, 361, 720)"
Count,167 Tasks,1 Chunks
Type,float64,numpy.ndarray


In [4]:
interped.salt  # salinity on the tracer grid; still lazy (dask), not yet computed

Unnamed: 0,Array,Chunk
Bytes,148.32 MiB,148.32 MiB
Shape,"(1, 75, 360, 720)","(1, 75, 360, 720)"
Count,79 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 148.32 MiB 148.32 MiB Shape (1, 75, 360, 720) (1, 75, 360, 720) Count 79 Tasks 1 Chunks Type float64 numpy.ndarray",1  1  720  360  75,

Unnamed: 0,Array,Chunk
Bytes,148.32 MiB,148.32 MiB
Shape,"(1, 75, 360, 720)","(1, 75, 360, 720)"
Count,79 Tasks,1 Chunks
Type,float64,numpy.ndarray


In [6]:
interped.salt.isel(time=0, zl=0).values  # .values forces computation of this slice (spot check)

array([[33.29390547, 33.24114025, 33.18731486, ..., 36.5299005 ,
        36.52812307, 36.52581756],
       [33.28601805, 33.23322574, 33.17940591, ..., 36.52740275,
        36.52498074, 36.52231852],
       [33.27916829, 33.22636577, 33.17253654, ..., 36.52462208,
        36.52195954, 36.51905741],
       ...,
       [35.23175759, 35.33545508, 35.43469377, ..., 29.6719132 ,
        29.6135553 , 29.5840168 ],
       [35.13257513, 35.2378583 , 35.33874478, ..., 29.73883376,
        29.63723769, 29.60085991],
       [35.03775638, 35.14469373, 35.24717229, ..., 29.8102779 ,
        29.70612139, 29.6239304 ]])

In [9]:
# Write the initial-condition file: netCDF3 64-bit offset, unlimited time
# dimension, and the _FillValue-free encodings built above.
interped.to_netcdf(
    output_file,
    engine='netcdf4',
    format='NETCDF3_64BIT',
    unlimited_dims='time',
    encoding=encodings,
)

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



In [2]:
def main():
    """Build the SODA-based initial-condition file for the NWA domain.

    Source data:
    https://dsrs.atmos.umd.edu/DATA/soda3.12.2/REGRIDED/ocean/soda3.12.2_5dy_ocean_reg_1993_01_04.nc

    NOTE(review): this SODA file was described as centred on 1992-12-30 with
    a model start of 1993-01-01, yet start_date below is 1993-01-04 — confirm
    which date the model actually expects.
    """
    start_date = np.datetime64('1993-01-04T00:00:00')
    # The day string is embedded in the output filename; derived, don't edit.
    start_str = np.datetime_as_string(start_date, unit='D')

    write_initial(
        soda_file='/Users/james/Downloads/soda3.12.2_5dy_ocean_reg_1993_01_04.nc',
        vgrid_file='/Users/james/Documents/Github/esm_lab/obc_ic/vgrid_75_2m.nc',
        grid_file='/Users/james/Downloads/nwa_ocean_hgrid.nc',
        start_date=start_date,
        output_file=f'/Users/james/Documents/Github/esm_lab/obc_ic/nwa_shared/soda_ic_75z_{start_str}.nc',
    )


if __name__ == '__main__':
    main()

  ds_out = xr.apply_ufunc(
  ds_out = xr.apply_ufunc(
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt

