In [None]:
import calendar
import datetime
import math
from datetime import timedelta as delta
from glob import glob
from operator import attrgetter

import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.mpl.ticker as cticker
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

from IPython.display import HTML
from matplotlib.animation import FuncAnimation, PillowWriter, writers
import matplotlib.animation as animation
from matplotlib import colors

from parcels import (
    ErrorCode,
    AdvectionRK4,
    FieldSet,
    JITParticle,
    ParticleSet,
    Variable,
    XarrayDecodedFilter,
    download_example_dataset,
    logger,Geographic,GeographicPolar
)

# Attach Parcels' XarrayDecodedFilter to the Parcels logger so the
# xarray decoding warning is not repeated for every NetCDF file read.
_xarray_decoded_filter = XarrayDecodedFilter()
logger.addFilter(_xarray_decoded_filter)

In [None]:
class PParticle(JITParticle):
    """JIT particle carrying the extra state needed by the TotalDistance kernel.

    ``prev_lon``/``prev_lat`` are initialised from the particle's own
    ``lon``/``lat`` (i.e. the release position) and are not written to the
    output file (``to_write=False``); ``distance`` starts at 0 and is written.
    """
    distance = Variable('distance', initial=0., dtype=np.float32)  # total distance travelled (km, as accumulated by TotalDistance)
    prev_lon = Variable('prev_lon', dtype=np.float32, to_write=False,
                        initial=attrgetter('lon'))  # longitude at the previous kernel call
    prev_lat = Variable('prev_lat', dtype=np.float32, to_write=False,
                        initial=attrgetter('lat'))  # latitude at the previous kernel call

    
# Keeping track of the total distance travelled by a particle:
def TotalDistance(particle, fieldset, time):
    """Add the distance (km) moved since the previous call to ``particle.distance``.

    Flat-earth approximation: 1.11e2 km per degree of latitude and
    1.11e2 * cos(lat) km per degree of longitude. The step is the Euclidean
    norm of the two components, so it is never negative and the total grows
    regardless of the sign of dt (backward runs included).

    Reads particle.lon/lat and particle.prev_lon/prev_lat; increments
    particle.distance and then stores the current position in
    prev_lon/prev_lat for the next call. ``fieldset`` and ``time`` are unused.
    """
    # Meridional component (km)
    lat_dist = (particle.lat - particle.prev_lat) * 1.11e2
    # Zonal component (km): a degree of longitude shrinks with cos(latitude)
    lon_dist = (
        (particle.lon - particle.prev_lon)
        * 1.11e2
        * math.cos(particle.lat * math.pi / 180)
    )
    # Calculate the total Euclidean distance travelled by the particle
    particle.distance += math.sqrt(math.pow(lon_dist, 2) + math.pow(lat_dist, 2))

    # Set the stored values for next iteration
    particle.prev_lon = particle.lon
    particle.prev_lat = particle.lat

def DeleteParticle(particle, fieldset, time):
    """Recovery kernel that removes ``particle`` from its ParticleSet.

    ``fieldset`` and ``time`` are not used; they are accepted only so the
    function matches the (particle, fieldset, time) kernel signature.
    """
    particle.delete()

In [None]:
# Case of Pachia Ammos
# ---------------------------------------------------------
# Config
# ---------------------------------------------------------
data_folder = "/mnt/data/CMEMS-download/"  # modify to your data path

# Release location in decimal degrees (Pachia Ammos case)
init_lon, init_lat = 25.808353, 35.114241

# (tag used in output file names, month number)
months = [
    ('jan', 1), ('feb', 2), ('mar', 3), ('apr', 4),
    ('may', 5), ('jun', 6), ('jul', 7), ('aug', 8),
    ('sep', 9), ('oct', 10), ('nov', 11), ('dec', 12),
]

n_particles = 1        # per release -- NOTE(review): not referenced by the visible loop
runtime_days = 60      # length of each run, in days
output_dt_hours = 6    # interval between written outputs, in hours
StokesD = True         # add the Stokes drift (VSDX/VSDY) fields to the currents
oname = 'h'            # tag inserted into the output file names

# ---------------------------------------------------------
# Helpers for the main loop
# ---------------------------------------------------------
def _two_year_files(prefix, year):
    """Return sorted ``{prefix}-{year-1}-*.nc`` files followed by ``{prefix}-{year}-*.nc``.

    The previous year is included because the runs below integrate backwards
    in time (negative dt) for ``runtime_days``, so early-year releases need
    fields from the year before.

    Raises FileNotFoundError when neither year has a matching file in
    ``data_folder``, so a wrong path fails early with a clear message.
    """
    files = (
        sorted(glob(f"{data_folder}/{prefix}-{year - 1}-*.nc"))
        + sorted(glob(f"{data_folder}/{prefix}-{year}-*.nc"))
    )
    if not files:
        raise FileNotFoundError(
            f"No {prefix}-{year - 1}-*.nc or {prefix}-{year}-*.nc files in {data_folder}"
        )
    return files


def _daily_release_times(year, month):
    """Return one release time (00:00) for every calendar day of ``month`` in ``year``.

    calendar.monthrange gives the real month length, so February gets 28 or
    29 releases (no duplicated Feb 28 releases) and 31-day months include the 31st.
    """
    n_days = calendar.monthrange(year, month)[1]
    return [datetime.datetime(year, month, day) for day in range(1, n_days + 1)]


# ---------------------------------------------------------
# Main loop over years
# ---------------------------------------------------------
for y in range(2017, 2022):

    print(f"Building fieldsets for year {y} ...")

    # --- Currents fieldset (built once per year) ---
    cur_files = _two_year_files("r_med-cmcc-cur-rean-h", y)
    variables = {"U": "uo", "V": "vo"}
    dimensions = {
        "U": {"lon": "lon", "lat": "lat", "time": "time"},
        "V": {"lon": "lon", "lat": "lat", "time": "time"},
    }
    fieldset_cs = FieldSet.from_netcdf(
        cur_files, variables, dimensions, allow_time_extrapolation=True
    )

    if StokesD:
        # --- Stokes drift fieldset (VSDX/VSDY) ---
        sd_files = _two_year_files("r_med-hcmr-wav-rean-h", y)
        SD_variables = {"U": "VSDX", "V": "VSDY"}
        SD_dimensions = {
            "U": {"lon": "lon", "lat": "lat", "time": "time"},
            "V": {"lon": "lon", "lat": "lat", "time": "time"},
        }
        fieldset_sd = FieldSet.from_netcdf(
            sd_files, SD_variables, SD_dimensions, allow_time_extrapolation=True
        )

        # --- Advect with currents + Stokes drift: the U and V fields are summed ---
        fieldset = FieldSet(
            U=fieldset_cs.U + fieldset_sd.U,
            V=fieldset_cs.V + fieldset_sd.V,
        )
    else:
        fieldset = fieldset_cs

    # Release point: the currents-grid node nearest to (init_lon, init_lat).
    # NOTE(review): `ds` stays open (and is rebound every year) in case later
    # cells use it; close it once it is no longer needed.
    ds = xr.open_mfdataset(cur_files)
    point_ds = ds.sel(lon=init_lon, lat=init_lat, method='nearest')

    lon = point_ds.lon.values
    lat = point_ds.lat.values

    # ---------------------------------------------------------
    # One ParticleSet and one output file per month: a particle is
    # released at the same point on every calendar day of the month.
    # ---------------------------------------------------------
    for m in months:
        times = _daily_release_times(y, m[1])
        lons = [lon] * len(times)
        lats = [lat] * len(times)
        depths = [1.0] * len(times)

        pset = ParticleSet.from_list(
            fieldset=fieldset,
            pclass=PParticle,
            lon=lons,
            lat=lats,
            time=times,
            depth=depths,
        )

        # ---------------------------------------------------------
        # Run all releases of this month in one execution
        # ---------------------------------------------------------
        kernels = pset.Kernel(AdvectionRK4) + TotalDistance
        output_file = pset.ParticleFile(
            name=f"./Particles-PA-{oname}-{m[0]}-{y}.zarr",
            outputdt=delta(hours=output_dt_hours),
        )

        print(f"Running year {y}, month {m[0]} with {len(lons)} releases...")
        # Negative dt: each particle is integrated backwards in time from its
        # release date; particles that go out of bounds are deleted.
        pset.execute(
            kernels,
            runtime=delta(days=runtime_days),
            dt=-delta(minutes=15),
            output_file=output_file,
            recovery={ErrorCode.ErrorOutOfBounds: DeleteParticle},
        )