# In [None]:
from tqdm import tqdm
import geopandas as gpd
import xarray as xr
import rioxarray
import numpy as np
from shapely.geometry import Point, box
import os
from dateutil.relativedelta import relativedelta
from scipy.spatial import cKDTree


# -------------------------------------------------------------
# 1. Récupération du fichier mensuel du mois précédent
# -------------------------------------------------------------
def get_monthly_netcdf_path(dt, base_folder, stat_type):
    """
    Build the path of the monthly NetCDF file for the month preceding ``dt``.

    The file is looked up as
    ``<base_folder>/<YYYY>/MARS3D_<YYYYMM>_<stat_type>.nc`` where the year and
    month are those of ``dt`` minus one month.
    Example: dt = 2019-07-01 resolves to the June 2019 file.

    Raises:
        FileNotFoundError: if the resolved file does not exist on disk.
    """
    target = dt - relativedelta(months=1)

    candidate = os.path.join(
        base_folder,
        str(target.year),
        f"MARS3D_{target.strftime('%Y%m')}_{stat_type}.nc",
    )

    if os.path.exists(candidate):
        return candidate

    raise FileNotFoundError(f"Fichier introuvable : {candidate}")

# -------------------------------------------------------------
# 2. Extraction WS / VEL pour un polygone
# -------------------------------------------------------------
def _clipped_valid_values(field, poly):
    """Return the non-NaN values of ``field`` clipped to ``poly`` as a 1-D array."""
    clipped = field.rio.clip([poly], all_touched=True, drop=False)
    flat = clipped.values.flatten()
    return flat[~np.isnan(flat)]


def _nearest_valid_values(field, xv, yv, point, k):
    """
    Return the values of up to ``k`` non-NaN cells of ``field`` nearest to ``point``.

    ``xv`` / ``yv`` hold the cell-centre coordinates (same shape as ``field``)
    and ``point`` is any object exposing ``.x`` and ``.y``. Distances are plain
    Euclidean distances in the coordinate units of the grid.

    Always returns a 1-D array; it is empty when the field has no valid cell
    or when ``k`` is smaller than 1.
    """
    data = field.values
    valid = ~np.isnan(data)
    candidates = data[valid]

    k = min(k, candidates.size)
    if k < 1:
        # Nothing to query: an all-NaN field (or k <= 0) yields no fallback values.
        return candidates[:0]

    tree = cKDTree(np.column_stack([xv[valid], yv[valid]]))
    _, nearest = tree.query([point.x, point.y], k=k)

    # query() returns a scalar index when k == 1; force a 1-D result so the
    # caller can always take its size and mean.
    return np.atleast_1d(candidates[nearest])


def get_ws_vel_for_poly(poly, ncdf_path, n_closest=3):
    """
    Extract WINDSTRESS and VELOCITY values for one polygon from a NetCDF file.

    The first time step of each variable is used; VELOCITY is additionally
    taken at the last index of its ``level`` dimension (presumably the surface
    layer -- confirm against the model output convention).

    Strategy:
    - Clip each variable to the polygon (all touched cells) and keep the
      non-NaN values.
    - Fallback, applied independently to each variable: when the clip yields no
      valid value, take the up to ``n_closest`` valid cells nearest to the
      polygon centroid.

    Args:
        poly: polygon geometry exposing ``.centroid`` and usable by ``rio.clip``.
        ncdf_path: path of the NetCDF file to read.
        n_closest: maximum number of nearest valid cells used by the fallback.

    Returns:
        Tuple ``(ws_vals, vel_vals, ws_mean, vel_mean)`` where the first two are
        1-D arrays of valid values and the last two their means (NaN when the
        corresponding array is empty).
    """
    # The context manager guarantees the file handle is released even when
    # clipping or the fallback raises.
    with xr.open_dataset(ncdf_path, engine="netcdf4") as ds:
        ws = ds["WINDSTRESS"].isel(time=0)
        vel = ds["VELOCITY"].isel(time=0, level=-1)

        # rio.clip needs a CRS; default to geographic lon/lat when none is set.
        if ws.rio.crs is None:
            ws = ws.rio.write_crs("EPSG:4326")
        if vel.rio.crs is None:
            vel = vel.rio.write_crs("EPSG:4326")

        ws_vals = _clipped_valid_values(ws, poly)
        vel_vals = _clipped_valid_values(vel, poly)

        if ws_vals.size == 0 or vel_vals.size == 0:
            # Cell-centre coordinates derived from the WINDSTRESS grid; they are
            # reused for VELOCITY, which assumes both variables share one grid.
            transform = ws.rio.transform()
            h, w = ws.shape
            xv, yv = transform * np.meshgrid(np.arange(w) + 0.5, np.arange(h) + 0.5)
            centroid = poly.centroid

            if ws_vals.size == 0:
                ws_vals = _nearest_valid_values(ws, xv, yv, centroid, n_closest)
            if vel_vals.size == 0:
                vel_vals = _nearest_valid_values(vel, xv, yv, centroid, n_closest)

    ws_mean = np.mean(ws_vals) if ws_vals.size > 0 else np.nan
    vel_mean = np.mean(vel_vals) if vel_vals.size > 0 else np.nan

    return ws_vals, vel_vals, ws_mean, vel_mean


# -------------------------------------------------------------
# 3. Programme principal
# -------------------------------------------------------------
def _reduce_valid(values, reducer):
    """Apply ``reducer`` to the non-NaN entries of ``values``; NaN if none remain."""
    arr = np.atleast_1d(np.asarray(values))
    arr = arr[~np.isnan(arr)]
    return reducer(arr) if arr.size else np.nan


def _monthly_stats_for_row(geometry, dt, base_folder):
    """
    Compute the six monthly statistics for one grid cell.

    Reads the "max", "min" and "mean" files of the month preceding ``dt`` (in
    that order) and returns
    ``(ws_max, ws_min, ws_mean, vel_max, vel_min, vel_mean)``.

    Raises FileNotFoundError if any of the three files is missing; nothing is
    returned in that case, so the caller never sees a partial result.
    """
    f_max = get_monthly_netcdf_path(dt, base_folder, "max")
    ws_vals, vel_vals, _, _ = get_ws_vel_for_poly(geometry, f_max)
    ws_max = _reduce_valid(ws_vals, np.max)
    vel_max = _reduce_valid(vel_vals, np.max)

    f_min = get_monthly_netcdf_path(dt, base_folder, "min")
    ws_vals, vel_vals, _, _ = get_ws_vel_for_poly(geometry, f_min)
    ws_min = _reduce_valid(ws_vals, np.min)
    vel_min = _reduce_valid(vel_vals, np.min)

    f_mean = get_monthly_netcdf_path(dt, base_folder, "mean")
    _, _, ws_mean, vel_mean = get_ws_vel_for_poly(geometry, f_mean)

    return ws_max, ws_min, ws_mean, vel_max, vel_min, vel_mean


def main(
    base_folder="/home/paulinev/Bureau/Marbec_data/BiodivMed/MARS3D/Med_MENOR/Aggregated/CUR-WIND_latlon/Monthly/Med-Est",
    grid_path="grille_med_est.geojson",
):
    """
    Add previous-month wind-stress / velocity statistics to every grid cell.

    Reads the grid from ``grid_path``, and for each row uses its ``date`` and
    ``geometry`` columns to extract the max / min / mean values from the
    monthly files under ``base_folder``. Six columns are added and the result
    is written back to ``grid_path`` as GeoJSON (the input file is overwritten).

    Rows for which one of the monthly files is missing get NaN in all six
    columns and a warning is printed.
    """
    columns = (
        "ws_max_surface",
        "ws_min_surface",
        "ws_mean_surface",
        "vel_max_surface",
        "vel_min_surface",
        "vel_mean_surface",
    )

    gdf = gpd.read_file(grid_path)

    # One tuple per row, appended exactly once per row so that every column
    # keeps the same length as the GeoDataFrame even when a file is missing
    # halfway through a row's extraction.
    stats_per_row = []

    for _, row in tqdm(gdf.iterrows(), total=len(gdf), desc="Extraction mensuelle"):
        dt = row['date']
        try:
            stats = _monthly_stats_for_row(row['geometry'], dt, base_folder)
        except FileNotFoundError as e:
            print(f"⚠️ Fichier manquant pour {dt}: {e}")
            stats = (np.nan,) * len(columns)
        stats_per_row.append(stats)

    # ---- Save ----
    for position, column in enumerate(columns):
        gdf[column] = [stats[position] for stats in stats_per_row]

    gdf.to_file(grid_path, driver="GeoJSON")


# Entry point: run the extraction only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
