In [1]:
import xarray as xr
import numpy as np
import xoak
from matplotlib import pyplot as plt
from cmocean import cm # for oceanography-specific colormaps
from tqdm import tqdm
import sys
import os
#import parcels

In [None]:
# #%------------- Set the paths
# NOTE(review): `path1` and `mesh_fn` are not defined anywhere in this notebook;
# they must already exist in the namespace before this cell runs.

# Simulation year: the tenth command-line argument (sys.argv[0] is the program name).
year = sys.argv[10]

# Mesh description file.
mesh_file = os.path.join(path1, mesh_fn)

# Yearly velocity files, one per component.
u_file, v_file, w_file = (
    os.path.join(path1, f"u.fesom.{year}.nc"),
    os.path.join(path1, f"v.fesom.{year}.nc"),
    os.path.join(path1, f"w.fesom.{year}.nc"),
)


In [2]:

#%------------- Particles
# Number of particles to release: first command-line argument.
num_particles = int(sys.argv[1])

# Release positions, drawn uniformly in lon [2, 3) and lat [7.5, 12.5)
# (longitudes first, then latitudes, as before).
# NOTE(review): no random seed is set, so positions differ between runs.
lon_start, lat_start = (
    np.random.uniform(low, high, size=(num_particles,))
    for low, high in ((2, 3), (7.5, 12.5))
)

# Tracking duration in days (second command-line argument) and the
# integration time step in minutes.
days = int(sys.argv[2])
minutes = 20

# Output interval for particle positions, in hours.
hours = 4

In [3]:
# Open the mesh through the joined path from the paths cell. Plain concatenation
# (path1 + mesh_fn) loses the separator when path1 has no trailing "/".
ds_mesh = xr.open_dataset(mesh_file)
# Label nodes and elements 1..N so that the entries of face_nodes can be used as
# labels (they are treated as 1-based here and below).
ds_mesh = ds_mesh.assign_coords(
    nod2=list(range(1, ds_mesh.sizes["nod2"] + 1)),
    elem=list(range(1, ds_mesh.sizes['elem'] + 1)),
)

# Corner coordinates of every element (n3 indexes its three corners).
elem_corner_lons = ds_mesh.lon.sel(nod2=ds_mesh.face_nodes)
elem_corner_lats = ds_mesh.lat.sel(nod2=ds_mesh.face_nodes)

# Elements whose corner longitudes span more than this straddle the periodic seam.
max_elem_lon_range = 0.2
tri_overlap = (elem_corner_lons.max('n3') - elem_corner_lons.min('n3')) > max_elem_lon_range

# For straddling elements, move the eastern corners back by one channel width so
# the three corners are contiguous before the centroid is taken.
near_channel_width = 4
channel_width = 4.5
elem_corner_lons_unglued = xr.where(
    tri_overlap & (elem_corner_lons > near_channel_width),
    elem_corner_lons - channel_width,
    elem_corner_lons,
)

# Element centroids (mean over the three corners).
elem_center_lons_unglued = elem_corner_lons_unglued.mean('n3')
elem_center_lats = elem_corner_lats.mean('n3')
# Centroids without the seam correction (not used for the lookup below).
elem_center_lons = elem_corner_lons.mean('n3')

# Attach the corrected centroids and build a nearest-neighbour index on them.
ds_mesh = ds_mesh.assign_coords(
    elem_center_lons=elem_center_lons_unglued,
    elem_center_lats=elem_center_lats,
)
ds_mesh.xoak.set_index(['elem_center_lats', 'elem_center_lons'], 'sklearn_geo_balltree')

# Regular target grid covering the channel (tuples: immutable bounds).
channel_lon_bds = (0, channel_width)
channel_lat_bds = (0, 18)
number_lon = 2 * 72
number_lat = 2 * 292

grid_lon = xr.DataArray(np.linspace(*channel_lon_bds, number_lon),
                        dims=('grid_lon',))
grid_lat = xr.DataArray(np.linspace(*channel_lat_bds, number_lat),
                        dims=('grid_lat',))

# Broadcast to 2-D (grid_lon, grid_lat) target coordinate arrays.
target_lon, target_lat = xr.broadcast(grid_lon, grid_lat)

# Label of the element nearest to each target point.
grid_elems = ds_mesh.xoak.sel(
    elem_center_lats=target_lat,
    elem_center_lons=target_lon,
).elem

grid_elems = grid_elems.assign_coords(
    target_lat=target_lat,
    target_lon=target_lon,
)

grid_elems = grid_elems.assign_coords(
    grid_lat=grid_lat,
    grid_lon=grid_lon,
)

# Node coordinates as 1-D nod2 coordinates, then a nearest-neighbour index on nodes.
ds_mesh = ds_mesh.assign_coords(
    lat=("nod2", ds_mesh.lat.data.flatten()),
    lon=("nod2", ds_mesh.lon.data.flatten()),
)
ds_mesh.xoak.set_index(["lat", "lon"], "sklearn_geo_balltree")

#-------------get the nod2grids
# Label of the node nearest to each target point (for node-based fields such as w).
grid_nodes = ds_mesh.xoak.sel(
    lat=target_lat,
    lon=target_lon,
).nod2

## Common depth levels (nz and nz1 merged and refined)

In [4]:
# Vertical coordinates of the mesh: nz (#41 values) and nz1 (#40 values).
za = ds_mesh.nz.values
zb = ds_mesh.nz1.values

# Merge both sets of levels into one ascending array.
zc = np.sort(np.concatenate((za, zb)))

# Refine: add the midpoint of every pair of neighbouring levels, then re-sort.
zg = np.sort(np.concatenate((zc, 0.5 * (zc[:-1] + zc[1:]))))

# Snap every refined depth to the nearest nz / nz1 level (repeated values are
# expected, as the printed arrays show).
nz_grid = ds_mesh.sel(nz=zg, method='nearest').nz
print(nz_grid.astype(int))

nz1_grid = ds_mesh.sel(nz1=zg, method='nearest').nz1
print(nz1_grid.astype(int))

<xarray.DataArray 'nz' (nz: 161)> Size: 1kB
array([   0,    0,    9,    9,    9,    9,    9,   18,   18,   18,   18,
         29,   29,   29,   41,   41,   41,   41,   55,   55,   55,   55,
         69,   69,   69,   69,   85,   85,   85,   85,  103,  103,  103,
        103,  122,  122,  122,  122,  144,  144,  144,  144,  144,  167,
        167,  167,  193,  193,  193,  193,  221,  221,  221,  221,  252,
        252,  252,  252,  252,  287,  287,  287,  287,  324,  324,  324,
        324,  366,  366,  366,  412,  412,  412,  412,  462,  462,  462,
        462,  517,  517,  517,  517,  578,  578,  578,  578,  578,  645,
        645,  645,  718,  718,  718,  718,  799,  799,  799,  799,  888,
        888,  888,  888,  986,  986,  986,  986,  986, 1094, 1094, 1094,
       1212, 1212, 1212, 1212, 1343, 1343, 1343, 1343, 1486, 1486, 1486,
       1486, 1644, 1644, 1644, 1644, 1644, 1817, 1817, 1817, 2008, 2008,
       2008, 2008, 2008, 2218, 2218, 2218, 2449, 2449, 2449, 2449, 2703,
       

## Load the velocity data (U, V, W)

In [5]:
# Open the yearly velocity files whose paths are built in the paths cell.
# (u_path / v_path / w_path are never defined in this notebook.)
# One chunk per time step and per vertical level keeps the later selection lazy.
ds_u = xr.open_mfdataset(u_file,
                         chunks={'time': 1, 'nz1': 1})
ds_v = xr.open_mfdataset(v_file,
                         chunks={'time': 1, 'nz1': 1})
# w is chunked along nz (its vertical dimension) instead of nz1.
ds_w = xr.open_mfdataset(w_file,
                         chunks={'time': 1, 'nz': 1})



In [None]:
# Sample each velocity component on the regular target grid.
# grid_elems / grid_nodes hold 1-based labels, hence the "- 1" for positional isel.
# interp(method='nearest') then maps the vertical dimension onto the snapped
# levels from the depth-levels cell (these contain repeated values).
U_grid = ds_u.u.isel(elem=grid_elems - 1).interp(nz1=nz1_grid,method = 'nearest') 
V_grid = ds_v.v.isel(elem=grid_elems - 1).interp(nz1=nz1_grid,method = 'nearest') 
W_grid = ds_w.w.isel(nod2=grid_nodes - 1).interp(nz=nz_grid,method = 'nearest')

In [None]:
# Report the shape of each sampled velocity component (U, V, then W).
for sampled in (U_grid, V_grid, W_grid):
    print(sampled.shape)

In [None]:
# Bundle the three sampled velocity components into one Dataset.
ds_uv_grid = xr.Dataset(dict(U=U_grid, V=V_grid, W=W_grid))

In [None]:
# Inspect the combined dataset: U and V still use nz1, W uses nz.
ds_uv_grid
## Keep a single vertical coordinate (`nz`) and drop the other one (`nz1`)

In [None]:
# Remove the nz1 coordinate first, so that U and V can be renamed onto the nz
# dimension in the next cell without conflicting nz1 labels.
ds_uv_grid = ds_uv_grid.drop_vars('nz1')

In [None]:
# Put U and V on the same vertical dimension name as W.
# NOTE(review): U/V were snapped to nz1 levels but are now labelled with W's nz
# values; confirm this vertical relabelling is acceptable.
ds_uv_grid['U'] = ds_uv_grid['U'].rename({'nz1':'nz'})
ds_uv_grid['V'] = ds_uv_grid['V'].rename({'nz1':'nz'})

In [None]:
# Load the chunked (lazily opened) data into memory before handing it to Parcels.
ds_uv_grid = ds_uv_grid.compute()

In [None]:
#print(ds_uv_grid['U'].data.chunks)

## Particle tracking with Parcels

In [None]:
from parcels import ParticleSet
from parcels import JITParticle
from parcels import AdvectionRK4_3D
from parcels import AdvectionRK4
from datetime import timedelta
import numpy as np
from parcels import FieldSet

In [None]:
# Build the Parcels FieldSet from the regular-grid velocities, transposed to
# (time, nz, grid_lat, grid_lon) order and mapped to Parcels' dimension names.
# NOTE(review): Parcels treats depth as positive downward; check whether the sign
# of the vertical velocity W needs flipping for this model output.
fieldset = FieldSet.from_xarray_dataset(
    ds_uv_grid.transpose('time','nz','grid_lat','grid_lon'),
    variables={'U':"U", "V":"V", "W":"W"},
    dimensions={'lon':'grid_lon',
                'lat':'grid_lat',
                'depth':'nz',
                'time':'time',
               },
    time_periodic=False,
    allow_time_extrapolation=False,
)

In [None]:
# Show the vertical levels used as the FieldSet depth dimension.
ds_uv_grid.nz

## Halo

In [None]:
# Record the western/eastern grid edges BEFORE the halo is added; the periodic
# boundary kernel reads them as fieldset.halo_west / fieldset.halo_east.
fieldset.add_constant("halo_west", fieldset.U.grid.lon[0])
fieldset.add_constant("halo_east", fieldset.U.grid.lon[-1])
# Extend the grid zonally so interpolation works across the periodic boundary.
fieldset.add_periodic_halo(zonal=True)

# Parcels kernel: wrap particles zonally across the periodic channel.
# When a particle is west of fieldset.halo_west (or east of fieldset.halo_east),
# shift its longitude by one domain width through Parcels' particle_dlon increment.
# The second parameter must be named `fieldset` (it was misspelled `fielset`),
# otherwise the body silently uses the notebook-global `fieldset`.
def periodicBC(particle, fieldset, time):
    if particle.lon < fieldset.halo_west:
        particle_dlon += fieldset.halo_east - fieldset.halo_west
    elif particle.lon > fieldset.halo_east:
        particle_dlon -= fieldset.halo_east - fieldset.halo_west

In [None]:
## Time and depth initial conditions
# Every particle starts at the first time step of the velocity data.
time = np.full(num_particles, ds_uv_grid.time[0].data)
# Release depths drawn uniformly in [10, 50).
depth = np.random.uniform(10, 50, size=num_particles)


## Initiate particles

In [None]:
#Init particle set
pset = ParticleSet(
    fieldset=fieldset,
    pclass=JITParticle,
    lon = lon_start,
    lat = lat_start,
    depth=depth,
    time=time
) 

# pset = parcels.ParticleSet.from_line(
#     fieldset=fieldset,
#     pclass=parcels.JITParticle,
#     size=10,
#     start=(1.9, 52.5),
#     finish=(3.4, 51.6),
#     depth=1,
# )

# lon = np.random.uniform(2, 3, size=num_particles)  # Longitudes between 2 and 3
# lat = np.random.uniform(7.5, 12.5, size=num_particles)  # Latitudes between 7.5 and 12.5
# depth = np.random.choice(ds_uv_grid.nz, size=num_particles)  # Choose random depths

# lon_start = np.random.uniform(2,3,size=(num_particles,)) 
# lat_start = np.random.uniform(7.5, 12.5, size=(num_particles,))


In [None]:
# Record particle positions every `hours` hours.
# NOTE(review): out_path and out_fn are not defined in this notebook; the name is
# built by plain concatenation, matching how the trajectories are reopened later.
output_file = pset.ParticleFile(
    outputdt=timedelta(hours=hours),
    name=out_path + out_fn,
)

In [None]:
## Execute particles
# Advect with 3-D RK4 and then apply the periodic wrap, for `days` days with a
# `minutes`-minute time step, writing positions to output_file.
pset.execute(
    [AdvectionRK4_3D,periodicBC],
    runtime=timedelta(days=days),
    dt=timedelta(minutes=minutes),
    output_file= output_file
)
## TODO: try a different advection scheme

## Make a plot

In [None]:
import xarray as xr
import numpy as np
import xoak
from matplotlib import pyplot as plt
from cmocean import cm # for oceanography-specific colormaps
from itertools import zip_longest
from functools import reduce
from operator import add
from pathlib import Path
import tqdm

In [None]:
# Load the trajectories written by Parcels fully into memory and display them.
ds_traj = xr.open_zarr(out_path + out_fn + ".zarr").compute()
ds_traj

In [None]:
# Vertical position (z) of one example trajectory along its observations.
ds_traj.isel(trajectory=5).z.plot(marker='.')

In [None]:
## Quick look: one trajectory with the wrap-around jumps masked
# Steps where longitude jumps by more than 4 (a periodic wrap) are set to NaN so
# the line plot does not draw across the whole channel.
# NOTE(review): diff('obs') is one observation shorter, and where() aligns on obs,
# so the first observation is presumably dropped; confirm.
skip_this_step = abs(ds_traj.lon.diff('obs')) > 4.0
ds_traj_nowrap = ds_traj.where(~skip_this_step)
ds_traj_nowrap.isel(trajectory=0).to_pandas().plot.line(
    x='lon', y='lat',
)

In [None]:
def line_between(start, end):
    """Rasterise the straight line from (x0, y0) to (x1, y1).

    Steps one unit at a time along the longer axis and rounds the coordinates
    (a simple DDA line).

    Parameters
    ----------
    start : tuple
        Integer pixel coordinates (x0, y0) of the first point.
    end : tuple
        Integer pixel coordinates (x1, y1) of the last point.

    Returns
    -------
    list of tuple
        All points (x, y) from ``start`` to ``end``, both endpoints included.
        If ``start == end`` the result is the single point ``[start]``.
    """
    x0, y0 = start
    x1, y1 = end
    # One point per unit step along the longer axis, endpoints included.
    n_points = max(abs(x1 - x0), abs(y1 - y0)) + 1
    if n_points == 1:
        # Degenerate segment: the step size below would divide by zero.
        return [(round(x0), round(y0))]
    dx = (x1 - x0) / (n_points - 1)
    dy = (y1 - y0) / (n_points - 1)
    return [(round(x0 + n * dx), round(y0 + n * dy)) for n in range(n_points)]

def line_between_sequence(points):
    """Fill in the rasterised line along every consecutive pair of points.

    Parameters
    ----------
    points : sequence of tuple
        Pixel coordinates (x, y) in visiting order.

    Returns
    -------
    list of tuple
        The points of every segment from ``line_between``, concatenated with each
        shared vertex kept once. An empty input gives ``[]`` and a single point
        gives ``[point]``.
    """
    filled = []
    for start, end in zip(points[:-1], points[1:]):
        # Drop each segment's end point: it is the next segment's start point.
        # Extending one list keeps this linear; repeated `list + list` copies the
        # whole accumulated prefix for every segment (quadratic).
        filled.extend(line_between(start, end)[:-1])
    # Close with the final vertex (adds nothing for an empty input).
    filled.extend(points[-1:])
    return filled

In [None]:
### Unrolling
# Undo the periodic wrap: whenever longitude jumps by more than 4 between two
# observations, add (or subtract) one channel width of 4.5 from that point onward.
lon_step = ds_traj.lon.diff('obs')
ad_lon = 0 + 4.5 * (lon_step < -4) - 4.5 * (lon_step > 4)
# diff() has no entry for the first observation. Give it a zero correction so that
# lon_unrolled keeps every observation; otherwise the inner join drops obs 0 and
# lon_unrolled ends up one observation shorter than ds_traj.lat.
lon_unrolled = ds_traj.lon + ad_lon.cumsum('obs').reindex(obs=ds_traj.obs, fill_value=0)
lon_unrolled.isel(trajectory=0).plot()

In [None]:
# Pixels per channel replica: 45 columns across the ~4.5 width, 180 rows along ~18.0.
Nx, Ny = 45, 180

In [None]:
# The channel repeats every 4.5 longitude units. Find the replicas ("ghost
# images") spanned by the unrolled trajectories: from the floor of the western
# replica index to one past the eastern one.
pix_replica_lon = (
    lon_unrolled.min().compute().data[()] // 4.5,
    lon_unrolled.max().compute().data[()] // 4.5 + 1,
)
# Pixel column indices across all replicas.
pix_x_unrolled = np.arange(pix_replica_lon[0] * Nx, pix_replica_lon[1] * Nx).astype(int)


In [None]:
# Longitude extent covered by all replicas. The bounds are multiples of 4.5, so
# keep them as floats: int() would truncate e.g. 13.5 to 13 and misalign the
# pixel grid with the channel period.
lon_bds = (pix_replica_lon[0] * 4.5, pix_replica_lon[1] * 4.5)
lat_bds = (0, 18)
lon_bds, lat_bds

In [None]:
# Count trajectory pixel visits on an Ny x Nx grid, folded onto one channel replica.
n_replicas = int(pix_replica_lon[1] - pix_replica_lon[0])
# n_replicas*Nx pixels need n_replicas*Nx + 1 edges for each pixel to be exactly
# 4.5/Nx wide, so that folding with % Nx lands in the right column. The edges
# are built from pix_replica_lon so they sit exactly on multiples of 4.5.
lon_edges = np.linspace(pix_replica_lon[0] * 4.5, pix_replica_lon[1] * 4.5, n_replicas * Nx + 1)
lat_edges = np.linspace(*lat_bds, Ny + 1)

dens = np.zeros((Ny, Nx), dtype=int)

for traj in tqdm.tqdm(ds_traj.trajectory.isel(trajectory=slice(None, None, 1)).data):
    lon_traj_da = lon_unrolled.sel(trajectory=traj)
    lon_traj = lon_traj_da.compute().data
    # Take latitudes at exactly the observations present in lon_unrolled, so every
    # (lon, lat) pair belongs to the same observation.
    lat_traj = ds_traj.lat.sel(trajectory=traj, obs=lon_traj_da.obs).compute().data

    # Skip missing observations: digitize would put NaN in the last bin.
    valid = np.isfinite(lon_traj) & np.isfinite(lat_traj)
    if not valid.any():
        continue
    ix = np.digitize(lon_traj[valid], lon_edges) - 1
    # Latitude is not periodic: clamp points on or beyond the walls instead of wrapping.
    iy = np.clip(np.digitize(lat_traj[valid], lat_edges) - 1, 0, Ny - 1)
    point_list = list(zip(ix.tolist(), iy.tolist()))

    # Drop consecutive duplicates (particle still in the same pixel).
    point_list = [i for i, j in zip_longest(point_list, point_list[1:]) if i != j]
    pos, count = np.unique(np.array(line_between_sequence(point_list)), axis=0, return_counts=True)
    # Fold columns back onto a single replica.
    pos[:, 0] %= Nx
    # After folding, different replicas can map to the same (row, col). np.add.at
    # accumulates every repeat, whereas dens[idx] += count would add it only once.
    np.add.at(dens, (pos[:, 1], pos[:, 0]), count)

In [None]:
fig, ax = plt.subplots(1, 1)
fig.set_dpi(300)
# Row 0 of dens is the lowest latitude bin, so draw with origin="lower"; imshow's
# default (origin="upper") would flip the channel north-south.
im = ax.imshow(dens, origin="lower", extent=(0, 4.5, *lat_bds))
ax.set(xlabel="Longitude (folded onto one channel width)", ylabel="Latitude",
       title="Trajectory pixel-visit density")
fig.colorbar(im, ax=ax, label="Visit count");