In [2]:
import xarray as xr
import numpy as np
import xoak
from matplotlib import pyplot as plt
from cmocean import cm # for oceanography-specific colormaps

In [4]:
ds_mesh = xr.open_dataset("../../FESOM_data/channel/fesom.mesh.diag.nc")
#now we define new coords
ds_mesh = ds_mesh.assign_coords(
    nod2=list(range(1, ds_mesh.sizes["nod2"]+1)), 
    elem=list(range(1,ds_mesh.sizes['elem']+1)),
)

#corners
elem_corner_lons = ds_mesh.lon.sel(nod2=ds_mesh.face_nodes)
elem_corner_lats = ds_mesh.lat.sel(nod2=ds_mesh.face_nodes)

max_elem_lon_range = 0.2
tri_overlap=(elem_corner_lons.max('n3') - elem_corner_lons.min('n3')) > max_elem_lon_range

near_channel_width =4
channel_width = 4.5
elem_corner_lons_unglued = xr.where(tri_overlap & (elem_corner_lons > near_channel_width), 
                                   elem_corner_lons - channel_width, elem_corner_lons)


elem_center_lons_unglued = elem_corner_lons_unglued.mean('n3')
elem_center_lats = elem_corner_lats.mean('n3')

elem_center_lons = elem_corner_lons.mean('n3')

## assign coordinates to the mesh
ds_mesh = ds_mesh.assign_coords(
    elem_center_lons=elem_center_lons_unglued,
    elem_center_lats=elem_center_lats,
)
#nearest neighbour interpolation
ds_mesh.xoak.set_index(['elem_center_lats','elem_center_lons'], 'sklearn_geo_balltree')

channel_lon_bds = (0,4.5) # use inmutable objects
channel_lat_bds = (0,18)
number_lon = 2*72 
number_lat = 2*292

# w_lon = number_lon
# w_lat = number_lat
# w_lon = int(2*51.5)
# w_lat = int(2*206)

grid_lon = xr.DataArray(np.linspace(*channel_lon_bds,number_lon), 
                        dims=('grid_lon',))
grid_lat = xr.DataArray(np.linspace(*channel_lat_bds,number_lat),
                        dims=('grid_lat',))

#reorder the lat and lon into a C grid
target_lon, target_lat = xr.broadcast(grid_lon, grid_lat)

#select the grid elements
grid_elems = ds_mesh.xoak.sel(
    elem_center_lats = target_lat,
    elem_center_lons = target_lon,
).elem

grid_elems = grid_elems.assign_coords(
    target_lat = target_lat,
    target_lon = target_lon,
)

grid_elems = grid_elems.assign_coords(
    grid_lat=grid_lat,
    grid_lon=grid_lon,
)

## modify the mesh for nodes and 
ds_mesh = ds_mesh.assign_coords(
    lat=("nod2", ds_mesh.lat.data.flatten()),
    lon=("nod2", ds_mesh.lon.data.flatten()),
)
#
# Ensure the xoak index 
ds_mesh.xoak.set_index(["lat", "lon"], "sklearn_geo_balltree")

#-------------get the nod2grids
#grid_nodes
grid_nodes = ds_mesh.xoak.sel(
    lat = target_lat,
    lon = target_lon,
).nod2

## Equal depth levels

In [15]:
za = ds_mesh.nz.values #41
zb = ds_mesh.nz1.values #40

zc = np.array(sorted(np.concatenate((za, zb))))
#print(zc)

zg = np.sort(np.hstack((0.5 * (zc[0:-1] + zc[1:]), zc)))

#nz grid 
nz_grid = ds_mesh.sel(nz = zg, method = 'nearest').nz
print(nz_grid.astype(int))

nz1_grid = ds_mesh.sel(nz1 = zg, method = 'nearest').nz1
print(nz1_grid.astype(int))

<xarray.DataArray 'nz' (nz: 161)> Size: 1kB
array([   0,    0,    9,    9,    9,    9,    9,   18,   18,   18,   18,
         29,   29,   29,   41,   41,   41,   41,   55,   55,   55,   55,
         69,   69,   69,   69,   85,   85,   85,   85,  103,  103,  103,
        103,  122,  122,  122,  122,  144,  144,  144,  144,  144,  167,
        167,  167,  193,  193,  193,  193,  221,  221,  221,  221,  252,
        252,  252,  252,  252,  287,  287,  287,  287,  324,  324,  324,
        324,  366,  366,  366,  412,  412,  412,  412,  462,  462,  462,
        462,  517,  517,  517,  517,  578,  578,  578,  578,  578,  645,
        645,  645,  718,  718,  718,  718,  799,  799,  799,  799,  888,
        888,  888,  888,  986,  986,  986,  986,  986, 1094, 1094, 1094,
       1212, 1212, 1212, 1212, 1343, 1343, 1343, 1343, 1486, 1486, 1486,
       1486, 1644, 1644, 1644, 1644, 1644, 1817, 1817, 1817, 2008, 2008,
       2008, 2008, 2008, 2218, 2218, 2218, 2449, 2449, 2449, 2449, 2703,
       

## Load the data U,V,W

In [10]:
ds_u = xr.open_mfdataset('../../FESOM_data/channel/u.fesom.2005.nc',
                         chunks ={'time':1, 'nz1': 1}).isel(time=slice(0,4))
# first selecting only the surface nz1=0
ds_v = xr.open_mfdataset('../../FESOM_data/channel/v.fesom.2005.nc',
                         chunks = {'time':1, 'nz1':1}).isel(time=slice(0,4))

ds_w = xr.open_mfdataset('../../FESOM_data/channel/w.fesom.2005.nc',
                         chunks = {'time':1, 'nz':1}).isel(time=slice(0,4))




In [19]:
U_grid = ds_u.u.isel(elem=grid_elems - 1).interp(nz1=nz1_grid,method = 'nearest') 
V_grid = ds_v.v.isel(elem=grid_elems - 1).interp(nz1=nz1_grid,method = 'nearest') 
W_grid = ds_w.w.isel(nod2=grid_nodes - 1).interp(nz=nz_grid,method = 'nearest')

In [20]:
print(U_grid.shape)
print(V_grid.shape)
print(W_grid.shape)

(4, 161, 144, 584)
(4, 161, 144, 584)
(4, 161, 144, 584)


In [22]:
ds_uv_grid= xr.Dataset({
    'U':U_grid,
    "V":V_grid,
    "W":W_grid,
})

In [24]:
ds_uv_grid
## Keep only one Z and drop the other one

Unnamed: 0,Array,Chunk
Bytes,206.60 MiB,12.91 MiB
Shape,"(4, 161, 144, 584)","(1, 161, 36, 584)"
Dask graph,16 chunks in 16 graph layers,16 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 206.60 MiB 12.91 MiB Shape (4, 161, 144, 584) (1, 161, 36, 584) Dask graph 16 chunks in 16 graph layers Data type float32 numpy.ndarray",4  1  584  144  161,

Unnamed: 0,Array,Chunk
Bytes,206.60 MiB,12.91 MiB
Shape,"(4, 161, 144, 584)","(1, 161, 36, 584)"
Dask graph,16 chunks in 16 graph layers,16 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,206.60 MiB,12.91 MiB
Shape,"(4, 161, 144, 584)","(1, 161, 36, 584)"
Dask graph,16 chunks in 16 graph layers,16 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 206.60 MiB 12.91 MiB Shape (4, 161, 144, 584) (1, 161, 36, 584) Dask graph 16 chunks in 16 graph layers Data type float32 numpy.ndarray",4  1  584  144  161,

Unnamed: 0,Array,Chunk
Bytes,206.60 MiB,12.91 MiB
Shape,"(4, 161, 144, 584)","(1, 161, 36, 584)"
Dask graph,16 chunks in 16 graph layers,16 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,206.60 MiB,6.46 MiB
Shape,"(4, 161, 144, 584)","(1, 161, 18, 584)"
Dask graph,36 chunks in 16 graph layers,36 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 206.60 MiB 6.46 MiB Shape (4, 161, 144, 584) (1, 161, 18, 584) Dask graph 36 chunks in 16 graph layers Data type float32 numpy.ndarray",4  1  584  144  161,

Unnamed: 0,Array,Chunk
Bytes,206.60 MiB,6.46 MiB
Shape,"(4, 161, 144, 584)","(1, 161, 18, 584)"
Dask graph,36 chunks in 16 graph layers,36 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


## Now Parcels

In [21]:
from parcels import ParticleSet
from parcels import JITParticle
from parcels import AdvectionRK4
from datetime import timedelta
import numpy as np
from parcels import FieldSet

In [23]:
fieldset = FieldSet.from_xarray_dataset(
    ds_uv_grid.transpose('time','grid_lat','grid_lon'),
    variables={'U':"U", "V":"V", "W":":W"},
    dimensions={'lon':'grid_lon',
                'lat':'grid_lat',
                'time':'time'},
    time_periodic=False,
    allow_time_extrapolation=True,
)


ValueError: ('time', 'grid_lat', 'grid_lon') must be a permuted list of FrozenMappingWarningOnValuesAccess({'time': 4, 'grid_lon': 144, 'grid_lat': 584, 'nz1': 161, 'nz': 161}), unless `...` is included