In [2]:
import xarray as xr
import numpy as np
import xoak
from matplotlib import pyplot as plt
from cmocean import cm # for oceanography-specific colormaps
from tqdm import tqdm
#import parcels

In [3]:
### to do
def interp_fesom(
    path1 = None ,
    mesh_file = None, 
    u_file =  None,
    v_file = None,
    w_file = None,
):
    """Resample FESOM u, v and w fields onto a regular lon/lat/depth grid.

    The resampling is a nearest-neighbour lookup, not a true interpolation:
    every target grid point takes the value of the closest mesh element
    centre (for ``u`` and ``v``) or the closest mesh node (for ``w``), and
    every target depth takes the closest model level.

    Parameters
    ----------
    path1 : str
        Directory prefix of the input files. It is concatenated directly
        with the file names, so it must end with a path separator.
    mesh_file : str
        Name of the mesh file. It has to provide ``lon``, ``lat``,
        ``face_nodes``, ``nz`` and ``nz1`` and the dimensions ``nod2``,
        ``elem`` and ``n3``.
    u_file, v_file : str
        Files holding the variables ``u`` / ``v`` with dimensions ``elem``
        and ``nz1`` (plus ``time``).
    w_file : str
        File holding the variable ``w`` with dimensions ``nod2`` and ``nz``
        (plus ``time``).

    Returns
    -------
    tuple of xarray.DataArray
        ``(u_interp, v_interp, w_interp)``, each with dimensions
        ``('time', 'z', 'grid_lat', 'grid_lon')``.

    Raises
    ------
    TypeError
        If any of the arguments is left at its ``None`` default.
    """
    required = {
        "path1": path1,
        "mesh_file": mesh_file,
        "u_file": u_file,
        "v_file": v_file,
        "w_file": w_file,
    }
    missing = [arg for arg, value in required.items() if value is None]
    if missing:
        # Concatenating None with a str raised TypeError before as well; fail
        # early with a message that names the culprit instead.
        raise TypeError(
            "interp_fesom() requires values for: " + ", ".join(missing)
        )

    # load_dataset reads everything into memory and closes the file again, so
    # no NetCDF handle stays open when this function is called in a loop.
    ds_mesh = xr.load_dataset(path1 + mesh_file)

    # Label nodes and elements with 1-based ids, matching the 1-based node
    # references that are looked up through ``face_nodes`` below.
    ds_mesh = ds_mesh.assign_coords(
        nod2=list(range(1, ds_mesh.sizes["nod2"] + 1)),
        elem=list(range(1, ds_mesh.sizes["elem"] + 1)),
    )

    # Longitude / latitude of the corners of every element.
    elem_corner_lons = ds_mesh.lon.sel(nod2=ds_mesh.face_nodes)
    elem_corner_lats = ds_mesh.lat.sel(nod2=ds_mesh.face_nodes)

    # Elements whose corners span more than this in longitude are treated as
    # straddling the channel boundary (presumably the periodic wrap-around --
    # confirm against the mesh set-up).
    max_elem_lon_range = 0.2
    tri_overlap = (
        elem_corner_lons.max('n3') - elem_corner_lons.min('n3')
    ) > max_elem_lon_range

    # For those elements, shift the corners lying beyond `near_channel_width`
    # back by one channel width so all corners sit on the same side before
    # averaging.
    near_channel_width = 4
    channel_width = 4.5
    elem_corner_lons_unglued = xr.where(
        tri_overlap & (elem_corner_lons > near_channel_width),
        elem_corner_lons - channel_width,
        elem_corner_lons,
    )

    # Element centres = mean of the corner positions.
    ds_mesh = ds_mesh.assign_coords(
        elem_center_lons=elem_corner_lons_unglued.mean('n3'),
        elem_center_lats=elem_corner_lats.mean('n3'),
    )

    # Spatial index over the element centres for the nearest-element lookup.
    ds_mesh.xoak.set_index(
        ['elem_center_lats', 'elem_center_lons'], 'sklearn_geo_balltree'
    )

    # Regular target grid covering the channel.
    channel_lon_bds = (0, 4.5)
    channel_lat_bds = (0, 18)
    nlon = 2 * 72
    nlat = 2 * 292

    grid_lon = xr.DataArray(
        np.linspace(*channel_lon_bds, nlon), dims=('grid_lon',)
    )
    grid_lat = xr.DataArray(
        np.linspace(*channel_lat_bds, nlat), dims=('grid_lat',)
    )

    # 2-D arrays holding the position of every target grid point.
    target_lon, target_lat = xr.broadcast(grid_lon, grid_lat)

    # 1-based id of the element whose centre is closest to each grid point.
    grid_elems = ds_mesh.xoak.sel(
        elem_center_lats=target_lat,
        elem_center_lons=target_lon,
    ).elem
    grid_elems = grid_elems.assign_coords(
        target_lat=target_lat,
        target_lon=target_lon,
    )
    grid_elems = grid_elems.assign_coords(
        grid_lat=grid_lat,
        grid_lon=grid_lon,
    )

    # Promote the node positions to coordinates and re-index on them for the
    # nearest-node lookup (w lives on nodes, not on elements).
    ds_mesh = ds_mesh.assign_coords(
        lat=("nod2", ds_mesh.lat.data.flatten()),
        lon=("nod2", ds_mesh.lon.data.flatten()),
    )
    ds_mesh.xoak.set_index(["lat", "lon"], "sklearn_geo_balltree")

    # 1-based id of the node closest to each grid point.
    grid_nodes = ds_mesh.xoak.sel(
        lat=target_lat,
        lon=target_lon,
    ).nod2

    # Velocity fields, fully loaded into memory (files are closed afterwards).
    ds_u = xr.load_dataset(path1 + u_file)
    ds_v = xr.load_dataset(path1 + v_file)
    ds_w = xr.load_dataset(path1 + w_file)

    # Target depths: all ``nz1`` levels plus the first and last ``nz`` level.
    z_target = xr.DataArray(
        sorted(list(ds_mesh.nz1.data) + list(ds_mesh.nz.data[[0, -1]])),
        dims='z',
    )

    # ids are 1-based, positional indexing is 0-based -> subtract one.
    u_interp = ds_u.u.isel(elem=grid_elems - 1).sel(nz1=z_target, method='nearest')
    v_interp = ds_v.v.isel(elem=grid_elems - 1).sel(nz1=z_target, method='nearest')
    w_interp = ds_w.w.isel(nod2=grid_nodes - 1).sel(nz=z_target, method='nearest')

    # Same 1-D axis coordinates and dimension order for all three fields.
    axis_coords = dict(
        grid_lon=target_lon.isel(grid_lat=0, drop=True),
        grid_lat=target_lat.isel(grid_lon=0, drop=True),
        z=z_target,
    )
    out_dims = ('time', 'z', 'grid_lat', 'grid_lon')

    # u and v carry the selected source level as ``nz1``; rename it to ``nz``
    # so that all three outputs use the same coordinate name.
    u_interp = (
        u_interp.assign_coords(**axis_coords)
        .rename({'nz1': 'nz'})
        .transpose(*out_dims)
    )
    v_interp = (
        v_interp.assign_coords(**axis_coords)
        .rename({'nz1': 'nz'})
        .transpose(*out_dims)
    )
    w_interp = w_interp.assign_coords(**axis_coords).transpose(*out_dims)

    return u_interp, v_interp, w_interp

In [4]:
# Input / output locations and the range of model years to process.
path1 = "../../FESOM_data/channel/"  # raw FESOM output; is concatenated with the file names, so keep the trailing "/"
mesh_file = "fesom.mesh.diag.nc"  # The FESOM mesh file
out_dir = "/gxfs_work/geomar/smomw662/FESOM_data/channel_interp"  # destination of the regridded files
first_year, last_year = 1960, 2057  # inclusive

for year in tqdm(range(first_year, last_year + 1)):

    u_file = f"u.fesom.{year}.nc"  # File containing U velocity
    v_file = f"v.fesom.{year}.nc"  # File containing V velocity
    w_file = f"w.fesom.{year}.nc"  # File containing W velocity

    u_interp, v_interp, w_interp = interp_fesom(
        path1=path1,
        mesh_file=mesh_file,
        u_file=u_file,
        v_file=v_file,
        w_file=w_file,
    )

    # One output file per velocity component and year: <out_dir>/<component>.<year>.nc
    # The encoding inherited from the source files is dropped before writing.
    for component, interp_field in (("u", u_interp), ("v", v_interp), ("w", w_interp)):
        interp_field.drop_encoding().to_netcdf(f"{out_dir}/{component}.{year}.nc")

100%|██████████████████████████████████████████████████████████████████████| 4/4 [07:34<00:00, 113.62s/it]


In [None]:
# w_interp.sel(grid_lon = 2.2, grid_lat = 8, method = 'nearest').plot(x='time',y='z',ylim=(4000, 0))

In [None]:
# for year in range(1960, 2057+1,1):
#     u_interp, v_interp, w_interp =  interp_fesom()
    
#     u_interp.drop_encoding().to_netcdf(f'/gxfs_work/geomar/smomw662/FESOM_data/channel_interp/u.{year}.nc') 
#     v_interp.drop_encoding().to_netcdf()
#     w_interp.drop_encoding().to_netcdf()
    