# ERA5 Enrichment for Telemetry Data
 
This notebook:
1. Loads ship telemetry from a parquet file
2. Determines the time range and spatial bounding boxes
3. Downloads ERA5 wind & wave data month by month (with a tight bounding box per month)
4. Matches ERA5 data to the ship track by nearest neighbour in time and space
5. Writes one parquet file per month
6. Reports per-month NaN coverage (optional QA)
7. Merges all ERA5 fields back into the telemetry dataset and saves a final parquet file

## --- 0. PREAMBLE & CONFIG ---
- Set input/output paths (and create the output folders)
- Define ERA5 variables and dataset name
- Define a tiny timestamped logger


In [46]:
import os
import time
from glob import glob
import traceback
import zipfile
import tempfile

import numpy as np
import pandas as pd
import xarray as xr
import cdsapi

# --- Filesystem locations (edit if needed) ---
TELEMETRY_PATH = "./data/TraviataDataForTesting.parquet"  # ship telemetry input
OUT_DIR = "./metocean_out"  # monthly ERA5 parquet parts + the final merged parquet
ERA5_TMP_DIR = "./_era5_tmp"  # raw monthly ERA5 downloads (.nc / .zip)

# Create both folders up front; no-op when they already exist.
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(ERA5_TMP_DIR, exist_ok=True)

# --- ERA5 request settings ---
# Wind and wave fields, requested by their CDS long names.
ERA5_VARS = [
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "significant_height_of_combined_wind_waves_and_swell",
    "mean_wave_period",
    "mean_wave_direction",
]

# CDS dataset holding the single-level reanalysis fields above.
ERA5_DATASET = "reanalysis-era5-single-levels"


def log(msg: str) -> None:
    """Print *msg* prefixed by the current local time as ``HH:MM:SS - ``."""
    stamp = time.strftime("%H:%M:%S")
    print(f"{stamp} - {msg}")

## --- 1. LOAD TELEMETRY & SCOPE ---
 - Load ship telemetry parquet
 - Ensure `timestamp` is tz-aware (UTC)
 - Add a `row_id` for stable merging
 - Build a minimal dataframe for ERA5 interpolation (`df_era`)
 - Compute overall time span
 - Define:
   - a global bbox (for reference)
   - a helper to compute monthly bboxes
   - the list of months to process


In [None]:
# Read the raw ship telemetry
df = pd.read_parquet(TELEMETRY_PATH)

# Parse timestamps as tz-aware UTC (utc=True localizes naive values as UTC)
df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True)

# Fresh positional index plus an explicit int64 key used to join ERA5 results back
df = df.reset_index(drop=True)
df["row_id"] = np.arange(len(df), dtype=np.int64)

# Only the columns needed for the ERA5 lookup
cols_needed = ["row_id", "timestamp", "Latitude_deg", "Longitude_deg", "HeadingTrue_deg"]
df_era = df.loc[:, cols_needed].copy()

# Overall telemetry time coverage
t_min, t_max = df_era["timestamp"].min(), df_era["timestamp"].max()
log(f"time span: {t_min} → {t_max}")


def bbox_quantile(d: pd.DataFrame, qlow=0.01, qhigh=0.99, pad=2.0):
    """
    Outlier-resistant bounding box of all telemetry positions.

    The box spans the `qlow`..`qhigh` quantiles of ``Latitude_deg`` and
    ``Longitude_deg``. It is widened by `pad` degrees on every side and clamped
    to the valid lat/lon range. The result is ordered [N, W, S, E], the
    ``area`` order used in CDS requests.
    """
    lats = d["Latitude_deg"]
    lons = d["Longitude_deg"]
    south, north = (float(lats.quantile(q)) for q in (qlow, qhigh))
    west, east = (float(lons.quantile(q)) for q in (qlow, qhigh))

    return [
        min(north + pad, 90.0),
        max(west - pad, -180.0),
        max(south - pad, -90.0),
        min(east + pad, 180.0),
    ]


def month_bbox(df_m: pd.DataFrame, pad=2.0):
    """
    Tight bounding box of one month's telemetry, padded by `pad` degrees.

    Uses the exact min/max of ``Latitude_deg`` and ``Longitude_deg``. The box
    is clamped to [-90, 90] x [-180, 180] and returned as [N, W, S, E].
    """
    lat = df_m["Latitude_deg"]
    lon = df_m["Longitude_deg"]

    north = min(float(lat.max()) + pad, 90.0)
    west = max(float(lon.min()) - pad, -180.0)
    south = max(float(lat.min()) - pad, -90.0)
    east = min(float(lon.max()) + pad, 180.0)
    return [north, west, south, east]


def month_list(t0, t1):
    """
    Return the 'YYYY-MM' strings from t0's month to t1's month, inclusive.

    Tz-aware inputs are first converted to UTC and made naive. This matches
    the UTC month bucketing in `process_all_months`, and it avoids the pandas
    warning about dropping timezone information that ``to_period`` emits for
    tz-aware Timestamps. Naive inputs are used as-is.

    Returns an empty list when t0 falls in a later month than t1.
    """
    def _utc_month(t):
        # Period of the calendar month containing `t`, evaluated in UTC.
        ts = pd.Timestamp(t)
        if ts.tzinfo is not None:
            ts = ts.tz_convert("UTC").tz_localize(None)
        return ts.to_period("M")

    first, last = _utc_month(t0), _utc_month(t1)
    return [str(p) for p in pd.period_range(first, last, freq="M")]


# Reference-only global box (1st-99th percentile, 2 deg padding); the downloads use per-month boxes
area_global = bbox_quantile(df_era, qlow=0.01, qhigh=0.99, pad=2.0)
log(f"Global ERA5 bbox [N,W,S,E]: {area_global}")

# Every calendar month touched by the telemetry
months = month_list(t_min, t_max)
log(f"months to fetch: {months[:3]} ... {months[-3:]}  (total {len(months)})")

13:05:00 - time span: 2023-01-01 03:00:00+00:00 → 2025-07-24 23:59:00+00:00
13:05:00 - ERA5 bbox [N,W,S,E]: [59.64501496266664, -171.4044147, -39.30653573066667, 173.66493906666665]
13:05:00 - months to fetch: ['2023-01', '2023-02', '2023-03'] ... ['2025-05', '2025-06', '2025-07']  (total 31)


  cur = pd.Timestamp(t0).to_period("M")
  last = pd.Timestamp(t1).to_period("M")


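## --- 2. ERA5 ONE-MONTH DOWNLOADER ---
- Use the CDS API to download one month
- Skip the download if the `.nc` or `.zip` file for that month already exists
- We request NetCDF, but CDS may return a ZIP → detect it and rename the file to `.zip`
- The function returns the path to either the `.nc` or the `.zip` file (the caller handles both)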

In [None]:
def era5_download_month(year_month: str, area, variables, out_path_nc):
    """
    Download one month of ERA5 single-level data.
    If CDS returns a ZIP, rename it to .zip and return that path.
    Otherwise return the .nc path.

    The transfer is written to a temporary ``<base>.part`` file. It is moved
    to its final name only after ``retrieve`` returns. This way an interrupted
    or failed download never leaves a truncated ``.nc`` that a later run would
    mistake for a finished file and skip.

    Parameters
    ----------
    year_month : str
        'YYYY-MM'
    area : list[float]
        [N, W, S, E] bounding box
    variables : list[str]
        ERA5 variable names
    out_path_nc : str
        Target path for NetCDF (we may rename to .zip)

    Returns
    -------
    str
        Path to downloaded .nc or .zip file.
    """
    year, month = year_month.split("-")
    # Days 01-31 are requested for every month. NOTE(review): this relies on
    # CDS ignoring non-existent dates such as 02-30; confirm if requests fail.
    days = [f"{d:02d}" for d in range(1, 32)]
    hours = [f"{h:02d}:00" for h in range(24)]

    out_base = os.path.splitext(out_path_nc)[0]
    out_path_zip = out_base + ".zip"
    part_path = out_base + ".part"

    # Skip if already present (only complete downloads ever get these names)
    if os.path.isfile(out_path_nc):
        log(f"[skip] exists: {out_path_nc}")
        return out_path_nc
    if os.path.isfile(out_path_zip):
        log(f"[skip] exists: {out_path_zip}")
        return out_path_zip

    req = {
        "product_type": "reanalysis",
        "variable": variables,
        "year": year,
        "month": month,
        "day": days,
        "time": hours,
        "data_format": "netcdf",
        "download_format": "unarchived",  # sometimes ignored by CDS
        "area": area,  # [N, W, S, E]
    }

    log(f"[dl] {year_month} → {out_path_nc}")
    c = cdsapi.Client()
    try:
        c.retrieve(ERA5_DATASET, req, part_path)
    except BaseException:
        # Remove any partial file so the next run retries instead of using it.
        if os.path.exists(part_path):
            os.remove(part_path)
        raise

    # Detect if CDS actually gave us a zip, then move into place atomically.
    if zipfile.is_zipfile(part_path):
        os.replace(part_path, out_path_zip)
        log(f"[note] CDS returned ZIP → renamed to {out_path_zip}")
        return out_path_zip

    os.replace(part_path, out_path_nc)
    return out_path_nc

## --- 3. OPEN & INTERPOLATE ONE MONTH ---

 - Robust NetCDF opener (tries several xarray engines in turn)
 - Standardize the ERA5 dataset (`valid_time` → `time`, drop ensemble/`expver` dims, sort coordinates, rename variables)
 - Longitude wrapping to [0, 360)
 - Chunked interpolation:
   - Nearest in time (with tolerance)
   - Nearest in space
 - Support for both:
   - Plain `.nc` files
   - ZIP files containing one or more `.nc` members

In [None]:
def open_nc(path: str) -> xr.Dataset:
    """
    Open the NetCDF file at *path*, trying several backends in turn.

    The netcdf4, h5netcdf and scipy engines are attempted in that order; any
    exception from one of them just moves on to the next. If none succeeds,
    a final call lets xarray choose the engine itself, so its error propagates.
    """
    for backend in ("netcdf4", "h5netcdf", "scipy"):
        try:
            dataset = xr.open_dataset(path, engine=backend)
        except Exception:
            continue
        return dataset
    # Nothing worked: defer to xarray's own engine detection (and its error)
    return xr.open_dataset(path)


def standardize_ds(ds: xr.Dataset) -> xr.Dataset:
    """
    Normalise an ERA5 dataset so the interpolation code can rely on its layout.

    - a ``valid_time`` coordinate becomes ``time`` (unless ``time`` exists)
    - only the first ensemble member is kept when a ``number`` dim exists
    - an ``expver`` dim is collapsed with a NaN-skipping max
    - ``time`` / ``latitude`` / ``longitude`` are sorted ascending when present
    - CDS long variable names, if present, are mapped to short names
      (u10, v10, swh, mwp, mwd)
    """
    needs_time_rename = "valid_time" in ds.coords and "time" not in ds.coords
    if needs_time_rename:
        ds = ds.rename({"valid_time": "time"})

    if "number" in ds.dims:
        ds = ds.isel(number=0, drop=True)

    if "expver" in ds.dims:
        ds = ds.max("expver", skipna=True, keep_attrs=True)

    sortable = [name for name in ("time", "latitude", "longitude") if name in ds.coords]
    for name in sortable:
        ds = ds.sortby(name)

    long_to_short = {
        "10m_u_component_of_wind": "u10",
        "10m_v_component_of_wind": "v10",
        "significant_height_of_combined_wind_waves_and_swell": "swh",
        "mean_wave_period": "mwp",
        "mean_wave_direction": "mwd",
    }
    present = {
        long_name: short_name
        for long_name, short_name in long_to_short.items()
        if long_name in ds.data_vars
    }
    return ds.rename(present)


def wrap_lon_to(ds: xr.Dataset, mode: str = "360") -> xr.Dataset:
    """
    Ensure dataset longitudes match telemetry convention.

    mode="360"  → [0, 360)
    mode="-180" → [-180, 180)

    Wrapping is skipped when the longitudes already appear to be in the target
    convention (judged from their maximum). After wrapping, the dataset is
    sorted by longitude.

    Duplicate longitudes are dropped afterwards, keeping the first occurrence.
    These arise e.g. when a grid spans -180..180: both ends map to 180 in
    "360" mode. Without this, the nearest-neighbour ``interp`` downstream
    would see a non-strictly-increasing axis.

    Datasets without a ``longitude`` coordinate are returned unchanged.
    """
    if "longitude" not in ds.coords:
        return ds

    lon = ds.longitude
    if mode == "360":
        if float(lon.max()) <= 180.0:
            ds = ds.assign_coords(longitude=((lon + 360) % 360)).sortby("longitude")
    else:
        if float(lon.max()) > 180.0:
            ds = ds.assign_coords(longitude=((lon + 180) % 360) - 180).sortby("longitude")

    if "longitude" in ds.dims:
        lon_vals = ds["longitude"].values
        # return_index gives the first position of each distinct longitude.
        _, first_idx = np.unique(lon_vals, return_index=True)
        if first_idx.size < lon_vals.size:
            ds = ds.isel(longitude=np.sort(first_idx))
    return ds


def interp_time_then_space_chunked_nearest(
    ds: xr.Dataset,
    df_m: pd.DataFrame,
    ym: str,
    time_tol="90min",
    batch_size=10_000,
) -> pd.DataFrame:
    """
    Chunked nearest-neighbour lookup of ERA5 onto the ship track for one month.

    For each batch of ``batch_size`` telemetry rows:
      1) pick the nearest ERA5 time step. Rows with no time step within
         ``time_tol`` get NaN instead of aborting the batch. xarray's
         ``sel(..., tolerance=...)`` raises KeyError in that case, which would
         discard the whole month.
      2) pick the nearest grid point in latitude/longitude via ``interp``.

    Parameters
    ----------
    ds : xr.Dataset
        ERA5 data whose ``time`` coordinate is sorted and unique. Longitudes
        are expected in [0, 360); telemetry longitudes are mapped to that
        range here.
    df_m : pd.DataFrame
        Telemetry rows with ``ts_naive`` (tz-naive), ``Latitude_deg``,
        ``Longitude_deg`` and ``row_id`` columns.
    ym : str
        Month label, used only in log messages.
    time_tol : str or timedelta-like
        Maximum allowed distance to the nearest ERA5 time step.
    batch_size : int
        Number of telemetry rows processed per batch.

    Returns
    -------
    pd.DataFrame
        Indexed by ``row_id``, one column per variable among u10, v10, swh,
        mwp, mwd that exists in ``ds``. Empty (with those columns) when
        ``df_m`` has no rows.
    """
    vars_ = [v for v in ["u10", "v10", "swh", "mwp", "mwd"] if v in ds.data_vars]
    n = len(df_m)
    if n == 0:
        return pd.DataFrame(columns=vars_, index=pd.Index([], dtype=np.int64, name="row_id"))

    t_all   = df_m["ts_naive"].to_numpy().astype("datetime64[ns]")
    # Telemetry longitudes mapped to [0, 360) to line up with the ERA5 grid.
    lon_all = (df_m["Longitude_deg"].to_numpy() + 360.0) % 360.0
    lat_all = df_m["Latitude_deg"].to_numpy()
    row_all = df_m["row_id"].to_numpy()

    time_index = ds.indexes["time"]
    tol = pd.Timedelta(time_tol)

    out_chunks = []
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        log(f"[{ym}] interp batch {start}:{end} of {n}")

        t_chunk   = t_all[start:end]
        lon_chunk = lon_all[start:end]
        lat_chunk = lat_all[start:end]
        row_chunk = row_all[start:end]

        # Nearest time step per row; -1 marks rows outside the tolerance.
        t_idx = time_index.get_indexer(t_chunk, method="nearest", tolerance=tol)
        valid = t_idx >= 0

        df_out_batch = pd.DataFrame(index=row_chunk)
        if valid.any():
            # Out-of-tolerance rows borrow index 0 here and are masked below.
            ds_t = ds.isel(time=xr.DataArray(np.where(valid, t_idx, 0), dims="p"))

            # Nearest in space
            ds_nearest = ds_t.interp(
                longitude=("p", lon_chunk),
                latitude=("p", lat_chunk),
                method="nearest",
            )
            for v in vars_:
                vals = ds_nearest[v].values
                if not valid.all():
                    vals = vals.astype(np.float64)  # copy that can hold NaN
                    vals[~valid] = np.nan
                df_out_batch[v] = vals
        else:
            # No row in this batch is close enough to any ERA5 time step.
            for v in vars_:
                df_out_batch[v] = np.nan

        df_out_batch.index.name = "row_id"
        out_chunks.append(df_out_batch)

    return pd.concat(out_chunks).sort_index()


def interpolate_zip_by_member_robust(zip_path: str, df_m: pd.DataFrame, ym: str) -> pd.DataFrame:
    """
    Open a CDS ZIP, interpolate EACH .nc member separately (on its native grid)
    using chunked nearest interpolation, then combine on row_id.

    Each member is streamed to a temporary ``.nc`` file, because the NetCDF
    backends need a real path. The file is opened in a ``with`` block so its
    handle is closed before the temp file is deleted; on Windows an open
    handle would make the deletion fail. Per-member results are merged with
    ``combine_first``, so columns from different members are unioned on
    ``row_id``.

    Raises
    ------
    FileNotFoundError
        If the archive contains no ``.nc`` member.
    ValueError
        If no member produced any interpolated rows.
    """
    import shutil

    results = []
    with zipfile.ZipFile(zip_path) as z:
        members = [m for m in z.namelist() if m.endswith(".nc")]
        if not members:
            raise FileNotFoundError(f"No .nc inside zip: {zip_path}")

        for m in members:
            one_nc_path = None
            try:
                # Stream the member to disk instead of loading it all into memory.
                with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp:
                    one_nc_path = tmp.name
                    with z.open(m) as src:
                        shutil.copyfileobj(src, tmp)

                with open_nc(one_nc_path) as ds_raw:
                    ds_one = standardize_ds(ds_raw)
                    ds_one = wrap_lon_to(ds_one, "360")

                    has = [v for v in ["u10", "v10", "swh", "mwp", "mwd"] if v in ds_one.data_vars]
                    log(f"[member] {ym}:{m} provides: {has}")

                    # Values are materialised into pandas here, before the file closes.
                    df_chunk = interp_time_then_space_chunked_nearest(
                        ds_one, df_m, ym, time_tol="90min", batch_size=10_000
                    )
                if not df_chunk.empty:
                    results.append(df_chunk)

            finally:
                # Best-effort cleanup of the extracted temp file.
                if one_nc_path is not None:
                    try:
                        os.remove(one_nc_path)
                    except OSError:
                        pass

    if not results:
        raise ValueError(f"{ym}: no variables were interpolated from ZIP members.")
    out = results[0]
    for r in results[1:]:
        out = out.combine_first(r)
    return out


def interpolate_one_file_robust(ds: xr.Dataset, df_m: pd.DataFrame, ym: str) -> pd.DataFrame:
    """
    Match one plain-NetCDF ERA5 month onto the telemetry rows in *df_m*.

    The dataset is standardised and its longitudes are wrapped to [0, 360).
    It is then passed to the chunked nearest-neighbour lookup with a 90 min
    time tolerance and 10 000-row batches.

    Raises
    ------
    ValueError
        If no ``time`` coordinate remains after standardisation.
    """
    prepared = wrap_lon_to(standardize_ds(ds), "360")
    if "time" in prepared.coords:
        return interp_time_then_space_chunked_nearest(
            prepared, df_m, ym, time_tol="90min", batch_size=10_000
        )
    raise ValueError(f"{ym}: dataset has no 'time' coordinate after standardize_ds.")

## --- 4. STREAM ALL MONTHS: DOWNLOAD → INTERP → WRITE → CLEAN ---
 - Loop over months
 - Compute a **tight bbox per month**
 - Download ERA5 (ZIP or NC)
 - Interpolate onto telemetry rows
 - Write one parquet per month: `era5_interp_YYYY-MM.parquet`
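 - Skip months that already have an output parquet (the calling cell filters them out before processing)
 - Raw ERA5 downloads are kept in `ERA5_TMP_DIR`; the cleanup step is commented out by default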

In [None]:
def process_all_months(df_all: pd.DataFrame, months, variables, out_dir, tmp_dir):
    """
    Download, interpolate and write ERA5 data for each month in `months`.

    For each month:
      - subset the telemetry rows falling in that (UTC) month
      - compute a month-specific padded bounding box
      - download ERA5 (.nc or .zip) into `tmp_dir` (skipped if already there)
      - nearest-neighbour match ERA5 onto the ship track
      - write ``era5_interp_YYYY-MM.parquet`` into `out_dir`

    Errors are handled per month. A failing month is logged with its traceback
    and skipped, so the remaining months still run. Failed months are listed
    in one summary log line at the end.

    Parameters
    ----------
    df_all : pd.DataFrame
        Telemetry with ``row_id``, tz-aware ``timestamp``, ``Latitude_deg`` and
        ``Longitude_deg``. Not modified (a copy is taken).
    months : iterable of str
        Months to process, as 'YYYY-MM'.
    variables : list[str]
        ERA5 variable names for the CDS request.
    out_dir : str
        Folder for the monthly parquet parts.
    tmp_dir : str
        Folder for the raw ERA5 downloads.

    Returns
    -------
    list[str]
        Paths of the monthly parquet files written by this call.
    """
    df_all = df_all.copy()

    # tz-naive helper for xarray
    if "ts_naive" not in df_all.columns:
        df_all["ts_naive"] = df_all["timestamp"].dt.tz_convert("UTC").dt.tz_localize(None)

    # Month label per row, computed once instead of once per month.
    row_month = df_all["ts_naive"].dt.to_period("M")

    out_parts = []
    failed = []

    for ym in months:
        ds = None
        file_path = None
        try:
            # rows in this month
            df_m = df_all.loc[
                row_month == pd.Period(ym),
                ["row_id", "timestamp", "ts_naive", "Latitude_deg", "Longitude_deg"],
            ].copy()

            if df_m.empty:
                log(f"[skip ] {ym}: no rows in this month")
                continue

            # Month-specific bbox
            bbox_m = month_bbox(df_m, pad=2.0)
            log(f"[bbox ] {ym}: {bbox_m}")

            # Download month (nc or zip)
            target_nc = os.path.join(tmp_dir, f"ERA5_{ym}.nc")
            file_path = era5_download_month(ym, bbox_m, variables, target_nc)
            log(f"[open ] {ym}: {file_path}")

            # Interpolate
            if zipfile.is_zipfile(file_path):
                df_w = interpolate_zip_by_member_robust(file_path, df_m, ym)
            else:
                ds = open_nc(file_path)
                try:
                    log(f"[info ] {ym}: vars={list(ds.data_vars)[:8]}...")
                    # ds.sizes is the dim-name → length mapping; newer xarray
                    # versions warn that Dataset.dims will become a plain set.
                    log(f"[info ] {ym}: coords={list(ds.coords)} | dims={dict(ds.sizes)}")
                except Exception:
                    pass  # diagnostics only; never block processing
                log(f"[interp] {ym} for {len(df_m)} rows …")
                df_w = interpolate_one_file_robust(ds, df_m, ym)

            # Write monthly parquet
            part_path = os.path.join(out_dir, f"era5_interp_{ym}.parquet")
            df_w.to_parquet(part_path)
            out_parts.append(part_path)
            log(f"[done ] {ym}: wrote {part_path} (cols={list(df_w.columns)})")

        except Exception as e:
            # Best-effort per month: record, report, and continue with the next one.
            failed.append(ym)
            log(f"[ERROR] {ym}: {e}")
            traceback.print_exc()

        finally:
            try:
                if ds is not None:
                    ds.close()
            except Exception:
                pass
            # If you want to clean the raw ERA5 file, you can uncomment:
            # if file_path and os.path.isfile(file_path):
            #     os.remove(file_path)
            #     log(f"[clean] removed {file_path}")

    if failed:
        log(f"[WARN ] {len(failed)} month(s) failed: {failed}")
    return out_parts


# Months that already have an era5_interp_YYYY-MM.parquet are treated as done
_done = {
    os.path.splitext(os.path.basename(p))[0][len("era5_interp_"):]
    for p in glob(os.path.join(OUT_DIR, "era5_interp_*.parquet"))
}
months_pending = [ym for ym in months if ym not in _done]
log(f"Remaining months: {months_pending}")

# Download → interpolate → write, one month at a time
parts = process_all_months(df_era, months_pending, ERA5_VARS, OUT_DIR, ERA5_TMP_DIR)

log(f"wrote {len(parts)} newly processed monthly parquet parts")

13:05:00 - Remaining months: []
13:05:00 - wrote 0 monthly parquet parts


## --- 5. COVERAGE REPORT PER MONTH (OPTIONAL QA) ---
 
 - For each monthly parquet, print:
   - number of rows
   - NaN coverage for `u10`, `v10`, `swh`, `mwp`, `mwd`


In [None]:
def coverage_report(df_like: pd.DataFrame, variables=("u10", "v10", "swh", "mwp", "mwd")):
    """
    Print the NaN count and fraction for each ERA5 column present in `df_like`.

    Parameters
    ----------
    df_like : pd.DataFrame
        A monthly ERA5 part, or any frame containing some of the ERA5 columns.
    variables : iterable of str, optional
        Columns to report on. Columns missing from `df_like` are skipped.

    A blank line is printed after the report. For an empty frame the fraction
    is shown as "n/a" instead of raising ZeroDivisionError.
    """
    n = len(df_like)
    for v in variables:
        if v not in df_like.columns:
            continue
        k = df_like[v].isna().sum()
        if n:
            print(f"    {v}: {k}/{n} = {k/n:.2%} NaN")
        else:
            print(f"    {v}: {k}/{n} = n/a (no rows)")
    print()


parqs = sorted(glob(os.path.join(OUT_DIR, "era5_interp_*.parquet")))
if not parqs:
    raise FileNotFoundError("No monthly parquet files found in OUT_DIR.")

print(f"Found {len(parqs)} monthly files\n")
for part_path in parqs:
    # "era5_interp_YYYY-MM.parquet" → "YYYY-MM"
    ym = os.path.splitext(os.path.basename(part_path))[0][len("era5_interp_"):]
    w = pd.read_parquet(part_path)
    print(f"=== {ym} ===\nRows: {len(w)}")
    coverage_report(w)



Found 31 monthly files

=== 2023-01 ===
Rows: 28865
    u10: 171/28865 = 0.59% NaN
    v10: 171/28865 = 0.59% NaN
    swh: 2098/28865 = 7.27% NaN
    mwp: 2098/28865 = 7.27% NaN
    mwd: 2098/28865 = 7.27% NaN

=== 2023-02 ===
Rows: 23156
    u10: 127/23156 = 0.55% NaN
    v10: 127/23156 = 0.55% NaN
    swh: 395/23156 = 1.71% NaN
    mwp: 395/23156 = 1.71% NaN
    mwd: 395/23156 = 1.71% NaN

=== 2023-03 ===
Rows: 31872
    u10: 0/31872 = 0.00% NaN
    v10: 0/31872 = 0.00% NaN
    swh: 627/31872 = 1.97% NaN
    mwp: 627/31872 = 1.97% NaN
    mwd: 627/31872 = 1.97% NaN

=== 2023-04 ===
Rows: 8851
    u10: 0/8851 = 0.00% NaN
    v10: 0/8851 = 0.00% NaN
    swh: 1452/8851 = 16.40% NaN
    mwp: 1452/8851 = 16.40% NaN
    mwd: 1452/8851 = 16.40% NaN

=== 2023-05 ===
Rows: 29316
    u10: 3/29316 = 0.01% NaN
    v10: 3/29316 = 0.01% NaN
    swh: 241/29316 = 0.82% NaN
    mwp: 241/29316 = 0.82% NaN
    mwd: 241/29316 = 0.82% NaN

=== 2023-06 ===
Rows: 38509
    u10: 54/38509 = 0.14% NaN
    v10

## --- 6. MERGE ALL ERA5 PARTS BACK ONTO TELEMETRY AND SAVE ---
 
 - Concatenate all monthly ERA5 parquet files
 - Merge on `row_id` with the original telemetry
 - Quick NaN check on `u10`/`v10`
 - Save final enriched dataset as parquet

In [None]:
# Gather every monthly ERA5 part written in step 4
part_files = sorted(glob(os.path.join(OUT_DIR, "era5_interp_*.parquet")))
if not part_files:
    # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate"
    raise FileNotFoundError(f"No era5_interp_*.parquet files found in {OUT_DIR}; run step 4 first.")
weather_all = pd.concat([pd.read_parquet(p) for p in part_files]).sort_index()
log(f"weather rows: {len(weather_all)} | columns: {list(weather_all.columns)}")

# Left-join on row_id (the index level of weather_all). validate="one_to_one"
# makes the merge fail loudly if a row_id appears in more than one part,
# instead of silently duplicating telemetry rows.
df_merged = df.merge(weather_all, on="row_id", how="left", validate="one_to_one")

# Quick quality check
nan_frac = df_merged[["u10", "v10"]].isna().any(axis=1).mean()
log(f"NaN fraction in ERA5 wind columns: {nan_frac:.2%}")

# Save final enriched dataset
FINAL_OUT = os.path.join(OUT_DIR, "telemetry_with_era5_wind_waves.parquet")
df_merged.to_parquet(FINAL_OUT, index=False)
log(f"saved: {FINAL_OUT}")

13:05:01 - weather rows: 780345 | columns: ['mwd', 'mwp', 'swh', 'u10', 'v10']
13:05:01 - NaN fraction in ERA5 wind columns: 0.26%
13:05:01 - saved: ./metocean_out/telemetry_with_era5_wind_waves.parquet
