##  Ocean Enrichment for Telemetry + ERA5
 
 This notebook:
 1. Loads telemetry already enriched with ERA5
 2. Determines time span and monthly bounding boxes
 3. Downloads Copernicus Marine ocean data (uo, vo, thetao, so) per month
 4. Interpolates ocean data to the ship track (time + space)
 5. Writes monthly parquet files
 6. Merges all ocean fields back into the telemetry+ERA5 dataset
 7. Computes surface seawater density with GSW (TEOS-10) from potential temperature (`thetao`) and salinity (`so`), and saves the final parquet file

##  --- 0. PREAMBLE & CONFIG ---
 - Set paths
 - Define Copernicus Marine dataset and variables
 - Simple logger

In [None]:
import os
import time
from glob import glob

import numpy as np
import pandas as pd
import xarray as xr

from copernicusmarine import subset as cm_subset  # Copernicus Marine subset API

# ---- Paths ----
TELEMETRY_PATH = "metocean_out/telemetry_with_era5_wind_waves.parquet"  # telemetry already enriched with ERA5

OUT_DIR_OCEAN = "./test"        # monthly ocean parquet parts and the final merged outputs
OCEAN_TMP_DIR = "./_ocean_tmp"  # downloaded monthly NetCDF files
os.makedirs(OUT_DIR_OCEAN, exist_ok=True)
os.makedirs(OCEAN_TMP_DIR, exist_ok=True)

# ---- Copernicus Marine dataset + variables ----
OCEAN_DATASET_ID = "cmems_mod_glo_phy_anfc_0.083deg_PT1H-m"  # hourly global physics
OCEAN_VARS = ["uo", "vo", "thetao", "so"]  # surface u, v, potential T, salinity


def log(msg: str) -> None:
    """Print ``msg`` prefixed by the current local time as ``HH:MM:SS - msg``."""
    stamp = time.strftime("%H:%M:%S")
    print(stamp, "-", msg)

##  --- 1. LOAD TELEMETRY & SCOPE ---
 - Load telemetry+ERA5 parquet
 - Ensure timestamp is UTC-aware
 - Add `row_id` as stable key
 - Build minimal DF for ocean interpolation
 - Compute overall time span
 - Define global bbox (info only) and monthly bbox helper
 - Build list of months to process


In [None]:
# Load telemetry+ERA5 and give every row a stable integer key (row_id); the key is
# what the interpolated ocean fields are joined back on later.
df = pd.read_parquet(TELEMETRY_PATH).reset_index(drop=True)
df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)
df["row_id"] = np.arange(len(df), dtype=np.int64)

# Minimal frame carried through the ocean interpolation
cols_needed = ["row_id", "timestamp", "Latitude_deg", "Longitude_deg"]
df_ocean = df[cols_needed].copy()

# Overall time span of the telemetry
t_min, t_max = df_ocean["timestamp"].min(), df_ocean["timestamp"].max()
log(f"time span: {t_min} → {t_max}")


def bbox_quantile(d: pd.DataFrame, qlow=0.01, qhigh=0.99, pad=2.0):
    """
    Return a padded overall bounding box as ``[N, W, S, E]``.

    The extents come from the ``qlow`` / ``qhigh`` quantiles of ``Latitude_deg``
    and ``Longitude_deg``, so isolated outliers do not inflate the box. Each side
    is then widened by ``pad`` degrees and clamped to [-90, 90] (latitude) and
    [-180, 180] (longitude).
    """
    lats = d["Latitude_deg"]
    lons = d["Longitude_deg"]

    north = min(float(lats.quantile(qhigh)) + pad, 90.0)
    south = max(float(lats.quantile(qlow)) - pad, -90.0)
    west = max(float(lons.quantile(qlow)) - pad, -180.0)
    east = min(float(lons.quantile(qhigh)) + pad, 180.0)
    return [north, west, south, east]


def month_bbox(df_m: pd.DataFrame, pad=2.0):
    """
    Return the padded min/max bounding box of one month of telemetry.

    Uses the plain min/max of ``Latitude_deg`` / ``Longitude_deg`` (no quantile
    trimming), widens each side by ``pad`` degrees and clamps to valid ranges.
    The result is ordered ``[N, W, S, E]`` (north, west, south, east), which is
    how the Copernicus download helper unpacks it.
    """
    lat = df_m["Latitude_deg"]
    lon = df_m["Longitude_deg"]

    north = min(float(lat.max()) + pad, 90.0)
    south = max(float(lat.min()) - pad, -90.0)
    west = max(float(lon.min()) - pad, -180.0)
    east = min(float(lon.max()) + pad, 180.0)
    return [north, west, south, east]


def month_list(t0, t1):
    """
    List every calendar month from ``t0`` to ``t1`` inclusive as 'YYYY-MM' strings.

    Returns an empty list when ``t0`` falls in a later month than ``t1``.
    """
    first = pd.Timestamp(t0).to_period("M")
    final = pd.Timestamp(t1).to_period("M")
    return [str(period) for period in pd.period_range(first, final, freq="M")]


# Reference only: overall bbox from the 1 %-99 % position quantiles, padded by 2°
area_global = bbox_quantile(df_ocean, 0.01, 0.99, 2.0)
log(f"OCEAN bbox [N,W,S,E]: {area_global}")

# Every calendar month touched by the telemetry, oldest first
months = month_list(t_min, t_max)
log(f"months to fetch: {months[:3]} ... {months[-3:]} (total {len(months)})")

# The interpolation step compares against tz-naive datetime64 model times, so keep
# a tz-naive copy of each timestamp expressed in UTC wall-clock time.
if "ts_naive" not in df_ocean.columns:
    df_ocean = df_ocean.assign(
        ts_naive=df_ocean["timestamp"].dt.tz_convert("UTC").dt.tz_localize(None)
    )



##  --- 2. DOWNLOAD ONE MONTH OF OCEAN DATA (Copernicus Marine) ---
 - Use Copernicus Marine `subset` API
 - Monthly time window
 - Monthly bbox
 - Save to NetCDF per month

In [None]:
def ocean_download_month(year_month: str, area, vars_list, out_nc_path: str) -> str:
    """
    Fetch one calendar month of ocean fields via the Copernicus Marine ``subset`` API.

    Parameters
    ----------
    year_month : str
        Month to fetch, formatted 'YYYY-MM'.
    area : sequence of 4 floats
        Bounding box ``[N, W, S, E]`` for THIS MONTH ONLY.
    vars_list : list of str
        Variables requested from ``OCEAN_DATASET_ID``.
    out_nc_path : str
        Destination NetCDF file. If it already exists, nothing is downloaded.

    Returns
    -------
    str
        ``out_nc_path``.
    """
    # Re-use an earlier download rather than fetching again
    if os.path.isfile(out_nc_path):
        log(f"[skip] ocean exists: {out_nc_path}")
        return out_nc_path

    year, month = (int(token) for token in year_month.split("-"))

    # Window from the first instant of this month to the first instant of the next.
    # NOTE(review): if the toolbox treats end_datetime as inclusive, the first model
    # record of the following month is also fetched, which helps match ship times
    # late on the last day. Confirm against the copernicusmarine docs.
    roll_over = month == 12
    start = pd.Timestamp(year=year, month=month, day=1, tz="UTC")
    end = pd.Timestamp(
        year=year + 1 if roll_over else year,
        month=1 if roll_over else month + 1,
        day=1,
        tz="UTC",
    )

    north, west, south, east = area  # [N, W, S, E]

    log(f"[dl-ocean] {year_month} bbox={area} → {out_nc_path}")
    cm_subset(
        dataset_id=OCEAN_DATASET_ID,
        variables=vars_list,
        start_datetime=start.isoformat(),
        end_datetime=end.isoformat(),
        minimum_longitude=west,
        maximum_longitude=east,
        minimum_latitude=south,
        maximum_latitude=north,
        output_filename=out_nc_path,
        overwrite=False,
    )
    return out_nc_path


##  --- 3. OPEN & INTERPOLATE ONE MONTH TO THE SHIP TRACK ---
 - Open monthly NetCDF with chunking
 - Keep only the first depth level (index 0, the near-surface layer) if a `depth` dimension exists, and drop the `depth` coordinate
 - Wrap longitudes to [0, 360)
 - Map each ship time to nearest model time (with tolerance)
 - Interpolate spatially (linear + nearest fallback) onto ship track

In [None]:
def open_ocean_nc(path: str) -> xr.Dataset:
    """
    Open a monthly ocean NetCDF with chunking and reduce it to a single depth level.

    Chunks are 24 time steps x 200 latitude x 200 longitude cells. If ``depth`` is a
    dimension, only its first level (index 0) is kept, and both the dimension and the
    coordinate are removed. If ``depth`` is only a coordinate, the coordinate is
    dropped. Otherwise the dataset is returned unchanged.
    """
    ds = xr.open_dataset(path, chunks={"time": 24, "latitude": 200, "longitude": 200})

    if "depth" in ds.dims:
        # Presumably index 0 is the surface level (CMEMS depth is usually ascending), TODO confirm
        return ds.isel(depth=0, drop=True)
    if "depth" in ds.coords:
        # Leftover non-dimension depth coordinate: remove it
        return ds.drop_vars("depth", errors="ignore")
    return ds


def wrap_lon_360(ds: xr.Dataset) -> xr.Dataset:
    """
    Map a [-180, 180] longitude axis onto [0, 360) and sort by longitude.

    The conversion only happens when the dataset has a ``longitude`` coordinate
    whose maximum is <= 180. Datasets with larger longitudes, or with no longitude
    coordinate at all, are returned unchanged.
    """
    if "longitude" in ds.coords and float(ds["longitude"].max()) <= 180.0:
        wrapped = (ds["longitude"] + 360) % 360
        ds = ds.assign_coords(longitude=wrapped).sortby("longitude")
    return ds


def _nearest_time_index(t_mod_i64: np.ndarray, t_req_i64: np.ndarray):
    """
    Match each requested time to its nearest model time.

    Parameters
    ----------
    t_mod_i64 : numpy.ndarray of int64
        Model times in int64 nanoseconds. Must be sorted ascending and non-empty.
    t_req_i64 : numpy.ndarray of int64
        Requested (ship) times in int64 nanoseconds.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        For each request: the index of the nearest model time, and the absolute
        distance to it in nanoseconds. Ties go to the earlier model time.
    """
    last = len(t_mod_i64) - 1
    right = np.searchsorted(t_mod_i64, t_req_i64)  # first model time >= request
    left = np.clip(right - 1, 0, last)
    right = np.clip(right, 0, last)

    d_left = np.abs(t_mod_i64[left] - t_req_i64)
    d_right = np.abs(t_mod_i64[right] - t_req_i64)
    take_right = d_right < d_left
    return np.where(take_right, right, left), np.where(take_right, d_right, d_left)


def ocean_interp_time_space(ds: xr.Dataset, df_m: pd.DataFrame, time_tol: str = "90min") -> pd.DataFrame:
    """
    Interpolate ocean ds (uo, vo, thetao, so) to ship positions+times.

    Each ship time is matched to the nearest model time step. Rows further than
    ``time_tol`` from every model time stay NaN. Within the matched step, values
    are interpolated linearly in latitude/longitude. Where the linear result is
    NaN, the nearest-neighbour value is used instead.

    Parameters
    ----------
    ds : xarray.Dataset
        Must have dims (time, latitude, longitude).
    df_m : pandas.DataFrame
        With columns: ['row_id', 'timestamp', 'ts_naive', 'Latitude_deg', 'Longitude_deg'].
    time_tol : str
        Max allowed time distance between ship time and model time (e.g. '90min').

    Returns
    -------
    pandas.DataFrame
        Index=row_id (one row per input row), columns=subset of [uo, vo, thetao, so]
        present in ``ds``, float32. Unmatched values are NaN.

    Raises
    ------
    ValueError
        If ``ds`` has no 'time' coordinate.
    """
    ds = wrap_lon_360(ds)
    if "time" not in ds.coords:
        raise ValueError("Ocean ds has no 'time' coord.")

    # Output frame: one row per input row, NaN until filled
    vars_ = [v for v in ["uo", "vo", "thetao", "so"] if v in ds.data_vars]
    out = pd.DataFrame(index=df_m["row_id"].values, columns=vars_, dtype="float32")
    out.index.name = "row_id"

    t_mod = ds["time"].values  # model times
    if t_mod.size == 0 or df_m.empty:
        # Nothing to match against (or nothing to fill); indexing an empty
        # time axis below would raise IndexError.
        return out

    # 1) Nearest model time index for every ship timestamp
    t_mod_i64 = t_mod.astype("datetime64[ns]").astype("int64")
    t_req_i64 = df_m["ts_naive"].to_numpy().astype("datetime64[ns]").astype("int64")
    nearest_idx, diff_ns = _nearest_time_index(t_mod_i64, t_req_i64)

    # Apply time tolerance
    tol_ns = int(pd.Timedelta(time_tol).to_numpy())  # nanoseconds
    valid = diff_ns <= tol_ns

    df_valid = df_m.loc[valid].copy()
    df_valid["time_idx"] = nearest_idx[valid]

    # 2) One spatial interpolation per model time step. groupby partitions the rows
    #    in a single pass instead of re-filtering the frame once per time step.
    for ti, df_chunk in df_valid.groupby("time_idx", sort=True):
        t_val = t_mod[ti]
        # Positional selection is robust to duplicated timestamps. Loading the slab
        # once lets both interpolations reuse it instead of re-reading the chunked
        # file per variable and method. This stays outside the try so read errors
        # propagate as before rather than leaving silently empty rows.
        ds_slice = ds.isel(time=int(ti)).load()

        lon360 = (df_chunk["Longitude_deg"].to_numpy() + 360.0) % 360.0
        lat = df_chunk["Latitude_deg"].to_numpy()

        try:
            primary = ds_slice.interp(
                longitude=("p", lon360),
                latitude=("p", lat),
                method="linear",
            )
            nearest = ds_slice.interp(
                longitude=("p", lon360),
                latitude=("p", lat),
                method="nearest",
            )
        except Exception as e:
            log(f"[interp-ocean] warning: time {t_val} failed: {e}")
            continue

        row_ids_chunk = df_chunk["row_id"].values
        for v in vars_:
            a = primary[v].values
            m = np.isnan(a)
            if m.any():
                # Linear interpolation is NaN next to masked (land) cells; fall back to nearest
                a[m] = nearest[v].values[m]
            out.loc[row_ids_chunk, v] = a.astype("float32")

    return out


##  --- 4. LOOP OVER MONTHS: DOWNLOAD → INTERP → WRITE ---
 - For each month:
   - Subset telemetry rows
   - Compute month-specific bbox
   - Download Copernicus subset
   - Interpolate onto ship track
   - Write monthly parquet: `ocean_interp_YYYY-MM.parquet`
 - Skip months already processed

In [None]:
import traceback

def process_ocean_all_months(df_all: pd.DataFrame, months, vars_list, out_dir, tmp_dir, pad: float = 2.0):
    """
    Loop over months, compute a tight bbox for EACH MONTH from telemetry,
    download Copernicus subset, interpolate to track, and write parquet.

    Parameters
    ----------
    df_all : pandas.DataFrame
        Telemetry with 'row_id', 'timestamp', 'ts_naive', 'Latitude_deg', 'Longitude_deg'.
    months : iterable of str
        Months to process, formatted 'YYYY-MM'.
    vars_list : list of str
        Ocean variables to download.
    out_dir : str
        Directory for ``ocean_interp_YYYY-MM.parquet`` outputs.
    tmp_dir : str
        Directory for the downloaded ``OCEAN_YYYY-MM.nc`` files.
    pad : float
        Degrees added around each month's track bbox.

    Returns
    -------
    list of str
        Parquet paths written by this call. Months with no rows, or whose
        processing raised, are not included.

    Errors inside a month are logged with a traceback and that month is skipped,
    so one bad month does not abort the run. The nc files are kept on disk.
    """
    out_parts = []
    keep_cols = ["row_id", "timestamp", "ts_naive", "Latitude_deg", "Longitude_deg"]

    # Month label of every row, computed once. Previously it was recomputed over
    # the whole frame for every month (O(months * rows)).
    row_month = df_all["ts_naive"].dt.to_period("M")

    for ym in months:
        ds = None
        nc_path = os.path.join(tmp_dir, f"OCEAN_{ym}.nc")

        try:
            # Subset telemetry for this month
            mask = row_month == pd.Period(ym)
            df_m = df_all.loc[mask, keep_cols].copy()
            if df_m.empty:
                log(f"[skip-ocean] {ym}: no rows in this month")
                continue

            # Month-specific bbox
            area_m = month_bbox(df_m, pad=pad)   # [N, W, S, E]
            log(f"[dl-ocean] {ym} bbox={area_m} → {nc_path}")

            # Download (or reuse)
            nc_path = ocean_download_month(ym, area_m, vars_list, nc_path)

            # Open and inspect
            log(f"[open-ocean] {ym}: {nc_path}")
            ds = open_ocean_nc(nc_path)
            try:
                log(f"[info-ocean] {ym}: vars={list(ds.data_vars)[:8]} ...")
                log(f"[info-ocean] {ym}: coords={list(ds.coords)}")
            except Exception:
                pass  # purely informational logging; never fail a month over it

            # Interpolate
            log(f"[interp-ocean] {ym} for {len(df_m)} rows …")
            df_o = ocean_interp_time_space(ds, df_m, time_tol="90min")

            # Write parquet
            part_path = os.path.join(out_dir, f"ocean_interp_{ym}.parquet")
            df_o.to_parquet(part_path)
            out_parts.append(part_path)
            log(f"[done-ocean] {ym}: wrote {part_path} (cols={list(df_o.columns)})")

        except Exception as e:
            # Best-effort per month: log and move on to the next one
            log(f"[ERROR-ocean] {ym}: {e}")
            traceback.print_exc()

        finally:
            if ds is not None:
                try:
                    ds.close()
                except Exception:
                    pass
            # If you want, you can clean nc_path here to save disk.

    return out_parts


# A month counts as done once its parquet part exists in OUT_DIR_OCEAN
_done = {
    os.path.splitext(os.path.basename(part))[0][len("ocean_interp_"):]
    for part in glob(os.path.join(OUT_DIR_OCEAN, "ocean_interp_*.parquet"))
}
months_todo = [ym for ym in months if ym not in _done]
log(f"Remaining OCEAN months: {months_todo}")

ocean_parts = process_ocean_all_months(
    df_ocean,
    months_todo,
    OCEAN_VARS,
    OUT_DIR_OCEAN,
    OCEAN_TMP_DIR,
    pad=2.0,  # same padding as the global bbox, applied per month
)

log(f"Ocean monthly parts written: {len(ocean_parts)}")

##  --- 5. MERGE OCEAN ONTO TELEMETRY+ERA5 ---
 - Concatenate all monthly ocean parquet files
 - Merge back onto original df via `row_id`
 - Save an intermediate parquet with ERA5 + ocean

In [None]:
part_files = sorted(glob(os.path.join(OUT_DIR_OCEAN, "ocean_interp_*.parquet")))
if not part_files:
    log("No ocean parquet parts found; nothing merged.")
    df_full = None  # signals to the density step that nothing was merged here
else:
    # Stack the monthly parts (each indexed by row_id) into one frame
    ocean_all = pd.concat(map(pd.read_parquet, part_files)).sort_index()
    log(f"ocean rows: {len(ocean_all)} | columns: {list(ocean_all.columns)}")

    # Left join on row_id: every telemetry row is kept, unmatched rows get NaN
    df_full = df.merge(ocean_all, on="row_id", how="left")

    FINAL_OCEAN = os.path.join(OUT_DIR_OCEAN, "telemetry_with_era5_ocean.parquet")
    df_full.to_parquet(FINAL_OCEAN, index=False)
    log(f"saved: {FINAL_OCEAN}")

##  --- 6. COMPUTE SURFACE DENSITY (GSW) & SAVE FINAL DATASET ---
 - Load the telemetry+ERA5+ocean file (or use df_full)
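 - Use GSW (TEOS-10) to compute in-situ density at the surface (p ≈ 0 dbar). GSW expects Absolute Salinity and Conservative Temperature, so practical salinity (`so`) and potential temperature (`thetao`) must be converted first.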
 - Save final dataset with `rho_surface`

In [None]:
import gsw

# Resume support: df_full is undefined in a fresh kernel (a plain `df_full is None`
# check would raise NameError there), or None when section 5 merged nothing.
if globals().get("df_full") is None:
    FINAL_OCEAN = os.path.join(OUT_DIR_OCEAN, "telemetry_with_era5_ocean.parquet")
    if not os.path.isfile(FINAL_OCEAN):
        raise FileNotFoundError(f"{FINAL_OCEAN} not found; run the ocean merge step first.")
    df_full = pd.read_parquet(FINAL_OCEAN)

# GSW (TEOS-10) functions take Absolute Salinity SA [g/kg] and Conservative
# Temperature CT [degC], not practical salinity / potential temperature.
# `thetao` is potential temperature. `so` is assumed to be practical salinity
# (TODO confirm against the dataset attributes). Convert both before computing density.
T = df_full["thetao"].to_numpy(dtype="float64")   # potential temperature
S = df_full["so"].to_numpy(dtype="float64")       # practical salinity (SP)
lon = df_full["Longitude_deg"].to_numpy(dtype="float64")
lat = df_full["Latitude_deg"].to_numpy(dtype="float64")
p = np.zeros_like(T)  # surface sea pressure ~ 0 dbar

SA = gsw.SA_from_SP(S, p, lon, lat)  # Absolute Salinity (location-dependent)
CT = gsw.CT_from_pt(SA, T)           # Conservative Temperature

rho = gsw.rho(SA, CT, p)  # in-situ density, kg/m^3
df_full["rho_surface"] = rho

FINAL_DENS = os.path.join(OUT_DIR_OCEAN, "telemetry_with_era5_ocean_density.parquet")
df_full.to_parquet(FINAL_DENS, index=False)
log(f"saved: {FINAL_DENS} with rho_surface")


In [None]:
import pandas as pd

final_path = "test/telemetry_with_era5_ocean_density.parquet"
df_final = pd.read_parquet(final_path)

print("Final shape:", df_final.shape)

# Preview only the columns that exist. A missing field (e.g. no ERA5 wind in this
# file) should not abort the whole check with a KeyError; the NaN report below
# already skips missing columns the same way.
preview_cols = [
    "timestamp",
    "Latitude_deg",
    "Longitude_deg",
    "u10", "v10",
    "uo", "vo",
    "thetao", "so",
    "rho_surface",
]
present_cols = [c for c in preview_cols if c in df_final.columns]
missing_cols = [c for c in preview_cols if c not in df_final.columns]
if missing_cols:
    print("Missing columns:", missing_cols)
print(df_final[present_cols].head())

print("\nNaN fractions:")
for c in ["u10", "v10", "uo", "vo", "thetao", "so", "rho_surface"]:
    if c in df_final.columns:
        print(f"{c:12s}: {df_final[c].isna().mean():6.3f}")
