# In [None]:  (notebook cell marker left by the export; commented out so the file parses as Python)
from tqdm import tqdm
import geopandas as gpd
import xarray as xr
import rioxarray
import numpy as np
from shapely.geometry import Point, box
import os
from datetime import timedelta
from dateutil.relativedelta import relativedelta


# -------------------------------------------------------------
# 1. Récupération du fichier mensuel du mois précédent
# -------------------------------------------------------------
def get_monthly_netcdf_path(dt, base_folder, stat_type):
    """
    Locate the monthly NetCDF file for the month that precedes ``dt``.

    Example: ``dt = 2019-07-01`` resolves to the June 2019 file.

    The expected layout is
    ``<base_folder>/<YYYY>/MARS3D_<YYYYMM>_<stat_type>.nc`` where the year
    and month are those of the previous month.

    Parameters
    ----------
    dt : date-like
        Reference date; it must support subtraction of a ``relativedelta``
        and expose ``year`` and ``strftime``.
    base_folder : str
        Root folder containing one sub-folder per year.
    stat_type : str
        Statistic suffix used in the file name.

    Returns
    -------
    str
        Path of the existing file.

    Raises
    ------
    FileNotFoundError
        If the expected file does not exist.
    """
    target_month = dt - relativedelta(months=1)
    stamp = target_month.strftime("%Y%m")

    candidate = os.path.join(
        base_folder,
        str(target_month.year),
        f"MARS3D_{stamp}_{stat_type}.nc",
    )

    # Guard clause inverted with respect to the usual "raise first" form:
    # return on success, fall through to the error otherwise.
    if os.path.exists(candidate):
        return candidate

    raise FileNotFoundError(f"Fichier introuvable : {candidate}")


# -------------------------------------------------------------
# 2. Extraction WS / VEL pour un polygone
# -------------------------------------------------------------
def get_ws_vel_for_poly(poly, ncdf_path, n_closest=3):
    """
    Extract wind-stress (WS) and velocity (VEL) values for one polygon from
    a monthly NetCDF file.

    The first time step of ``WINDSTRESS`` and the first time step / last
    ``level`` index of ``VELOCITY`` are read. Pixels intersecting the polygon
    are weighted by the fraction of the pixel covered by the polygon. If no
    intersecting pixel has a non-NaN wind stress, the ``n_closest`` pixels
    with a non-NaN wind stress nearest to the polygon are used instead, with
    uniform weights.

    Parameters
    ----------
    poly : shapely geometry
        Geometry to sample; it is interpreted as EPSG:4326 coordinates.
    ncdf_path : str
        Path of the NetCDF file to read.
    n_closest : int, optional
        Number of nearest valid pixels used by the fallback (default 3).

    Returns
    -------
    tuple
        ``(ws_values, vel_values, ws_mean, vel_mean)``. The two lists hold
        the values of the retained pixels (``vel_values`` may contain NaN);
        the two means are weighted means. When no valid pixel exists at all
        the result is ``([], [], nan, nan)``.

    Notes
    -----
    Each mean only uses the weights of pixels whose value is not NaN, so a
    NaN velocity no longer drags ``vel_mean`` towards zero, and ``vel_mean``
    is NaN (not 0.0) when every velocity is NaN. If the usable weights sum to
    zero (e.g. zero-area intersections), a plain mean is returned.
    """
    # Everything needed is copied into memory inside the ``with`` block so
    # the file handle is released even if a later step raises.
    with xr.open_dataset(ncdf_path, engine="netcdf4") as opened:
        ds = opened if hasattr(opened, 'crs') else opened.rio.write_crs("EPSG:4326")

        ws = ds['WINDSTRESS'].isel(time=0)            # 2D (y, x)
        vel = ds['VELOCITY'].isel(time=0, level=-1)   # 2D (y, x), last level index

        # A dataset exposing a ``crs`` name does not guarantee that
        # rioxarray can resolve a CRS for the variable itself.
        if ws.rio.crs is None:
            ws = ws.rio.write_crs("EPSG:4326")

        grid_crs = ws.rio.crs
        transform = ws.rio.transform()
        ws_grid = np.asarray(ws.values)
        vel_grid = np.asarray(vel.values)

    poly_proj = gpd.GeoSeries([poly], crs="EPSG:4326").to_crs(grid_crs).iloc[0]

    # 1) Pixels intersecting the polygon, weighted by covered fraction.
    coords, weights = _intersecting_pixels(poly_proj, transform, ws_grid.shape)
    ws_vals, vel_vals, kept_weights = _collect_valid_pixels(
        ws_grid, vel_grid, coords, weights
    )

    # 2) Nothing intersects, or every intersecting pixel is NaN:
    #    fall back to the nearest pixels that do carry a wind-stress value.
    if not ws_vals:
        coords, weights = _closest_valid_pixels(
            poly_proj, transform, ws_grid, n_closest
        )
        ws_vals, vel_vals, kept_weights = _collect_valid_pixels(
            ws_grid, vel_grid, coords, weights
        )

    # 3) Still nothing usable: the grid holds no valid wind stress.
    if not ws_vals:
        return [], [], np.nan, np.nan

    ws_mean = _weighted_nanmean(ws_vals, kept_weights)
    vel_mean = _weighted_nanmean(vel_vals, kept_weights)

    return ws_vals, vel_vals, ws_mean, vel_mean


def _intersecting_pixels(geom, transform, shape):
    """Return ``(coords, weights)`` of the pixels whose footprint intersects ``geom``."""
    coords, weights = [], []
    if geom.is_empty:
        # An empty geometry intersects nothing.
        return coords, weights

    height, width = shape
    g_minx, g_miny, g_maxx, g_maxy = geom.bounds

    for j in range(height):
        for i in range(width):
            x0, y0 = transform * (i, j)
            x1, y1 = transform * (i + 1, j + 1)

            # Cheap rejection: a pixel strictly outside the geometry's
            # bounding box cannot intersect it. Touching pixels are kept so
            # the exact test below decides.
            if (
                max(x0, x1) < g_minx
                or min(x0, x1) > g_maxx
                or max(y0, y1) < g_miny
                or min(y0, y1) > g_maxy
            ):
                continue

            pixel_poly = box(x0, y1, x1, y0)
            overlap = geom.intersection(pixel_poly)
            if not overlap.is_empty:
                coords.append((j, i))
                weights.append(overlap.area / pixel_poly.area)

    return coords, weights


def _closest_valid_pixels(geom, transform, ws_grid, n):
    """Return up to ``n`` pixels with non-NaN wind stress closest to ``geom``, uniformly weighted."""
    candidates = []
    # argwhere walks the grid in row-major order, so ties in distance keep
    # the same ordering as a nested row/column loop (the sort is stable).
    for jj, ii in np.argwhere(~np.isnan(ws_grid)).tolist():
        x_c, y_c = transform * (ii + 0.5, jj + 0.5)
        candidates.append((Point(x_c, y_c).distance(geom), jj, ii))

    candidates.sort(key=lambda item: item[0])
    coords = [(jj, ii) for _, jj, ii in candidates[:n]]
    return coords, [1.0] * len(coords)


def _collect_valid_pixels(ws_grid, vel_grid, coords, weights):
    """Read WS/VEL at ``coords``, keeping only pixels whose wind stress is not NaN."""
    ws_vals, vel_vals, kept_weights = [], [], []
    for wgt, (j, i) in zip(weights, coords):
        w = ws_grid[j, i]
        if np.isnan(w):
            continue
        ws_vals.append(float(w))
        vel_vals.append(float(vel_grid[j, i]))  # may be NaN; handled by the mean
        kept_weights.append(wgt)
    return ws_vals, vel_vals, kept_weights


def _weighted_nanmean(values, weights):
    """Weighted mean of the non-NaN ``values``; NaN when none is usable."""
    vals = np.asarray(values, dtype=float)
    wts = np.asarray(weights, dtype=float)

    usable = ~np.isnan(vals)
    if not usable.any():
        return np.nan

    vals, wts = vals[usable], wts[usable]
    total = wts.sum()
    if total <= 0:
        # Zero-area intersections give zero weights: use a plain mean
        # rather than dividing by zero.
        return float(vals.mean())
    return float(np.sum(vals * wts) / total)


# -------------------------------------------------------------
# 3. Programme principal
# -------------------------------------------------------------
def main(base_folder=None, grid_path="grille_med_est.geojson", output_path=None):
    """
    Add monthly wind-stress / velocity statistics to every cell of a grid.

    For each row of the grid, the monthly "max", "min" and "mean" NetCDF
    files of the month preceding ``row['date']`` are sampled over
    ``row['geometry']`` and six columns are written:
    ``ws_{max,min,mean}_surface`` and ``vel_{max,min,mean}_surface``.

    If any of the three monthly files is missing for a row, a warning is
    printed and the six values of that row are NaN. Values are appended
    exactly once per row, so the columns always match the grid length.

    Parameters
    ----------
    base_folder : str, optional
        Root folder of the monthly NetCDF files. Defaults to the original
        hard-coded location.
    grid_path : str, optional
        GeoJSON grid to read (default ``"grille_med_est.geojson"``).
    output_path : str, optional
        GeoJSON file to write. Defaults to ``grid_path`` (the input file is
        overwritten, as before).
    """
    if base_folder is None:
        base_folder = "/home/paulinev/Bureau/Marbec_data/BiodivMed/MARS3D/Med_MENOR/Aggregated/CUR-WIND_latlon/Monthly/Med-Est"
    if output_path is None:
        output_path = grid_path

    gdf = gpd.read_file(grid_path)

    column_names = (
        'ws_max_surface',
        'ws_min_surface',
        'ws_mean_surface',
        'vel_max_surface',
        'vel_min_surface',
        'vel_mean_surface',
    )
    columns = {name: [] for name in column_names}

    for _, row in tqdm(gdf.iterrows(), total=len(gdf), desc="Extraction mensuelle"):
        dt = row['date']
        geom = row['geometry']

        # Start from an all-NaN row; it is only filled in when the three
        # monthly files are available.
        values = dict.fromkeys(column_names, np.nan)

        try:
            # Resolve the three files first so a missing one is detected
            # before any extraction work is done.
            f_max = get_monthly_netcdf_path(dt, base_folder, "max")
            f_min = get_monthly_netcdf_path(dt, base_folder, "min")
            f_mean = get_monthly_netcdf_path(dt, base_folder, "mean")
        except FileNotFoundError as e:
            print(f"⚠️ Fichier manquant pour {dt}: {e}")
        else:
            # ---- MAX ----
            ws_vals, vel_vals, _, _ = get_ws_vel_for_poly(geom, f_max)
            values['ws_max_surface'] = np.nanmax(ws_vals) if len(ws_vals) else np.nan
            values['vel_max_surface'] = np.nanmax(vel_vals) if len(vel_vals) else np.nan

            # ---- MIN ----
            ws_vals, vel_vals, _, _ = get_ws_vel_for_poly(geom, f_min)
            values['ws_min_surface'] = np.nanmin(ws_vals) if len(ws_vals) else np.nan
            values['vel_min_surface'] = np.nanmin(vel_vals) if len(vel_vals) else np.nan

            # ---- MEAN ----
            _, _, ws_mean, vel_mean = get_ws_vel_for_poly(geom, f_mean)
            values['ws_mean_surface'] = ws_mean
            values['vel_mean_surface'] = vel_mean

        # Single append point: one value per column and per row.
        for name in column_names:
            columns[name].append(values[name])

    # ---- Save ----
    for name in column_names:
        gdf[name] = columns[name]

    gdf.to_file(output_path, driver="GeoJSON")


# Run the extraction only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
