<a href="https://colab.research.google.com/github/jaysmerrill/Yaquina_Bay_Seiching/blob/main/Seiche_prediction_TFT_method.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
# 46050 ingest -> align -> (optional) unified mask -> XGB top-k (TOP_LAG_FEATURES, 1 lag/base) -> Mini-TFT
# New in this drop-in:
#   • Fill NOAA 9435380 hourly wind gaps using HMSC monthly .dat files
#     - Parse HMSC (skip preamble, read from TIMESTAMP row)
#     - Convert PST(UTC-8) -> UTC by +8h
#     - Convert AWS mph -> m/s; AWD deg
#     - Build u/v (meteorological from-direction), resample 5-min -> hourly (vector mean)
#     - Fill ONLY missing 9435380 hourly values with HMSC hourly values
#   • No gust anywhere (neither NOAA nor HMSC)
# ============================================================
import io, gzip, requests, re, calendar
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
from contextlib import nullcontext
from sklearn.metrics import r2_score

# -------------------
# USER KNOBS
# -------------------
# Input matrix; only read when `df_seq` is not already defined in the session.
DATA_PATH           = "/content/input_matrix_for_ML_AI_applications.parquet"
# Candidate 46050 archive years; trimmed later to years before the current UTC year.
YEARS               = list(range(2008, 2025))
DT_MINUTES          = 6       # step (minutes) of the common alignment grid
MAX_GAP_HOURS       = 5       # gap threshold (hours) used when building the unified gap mask
USE_UNIFIED_GAP_MASK= False   # keep False to avoid over-masking
HISTORY_HOURS       = 12
HORIZONS_MIN        = [60, 120, 240]
PRIMARY_HORIZON     = 60      # horizon (minutes) of the shifted target used for XGB lag ranking

# NOAA wind (CO-OPS) fetch window (chunked) — baseline to be filled by HMSC
NOAA_BEGIN_DATE     = "20080101"   # parsed with format %Y%m%d
NOAA_END_DATE       = "20250101"   # parsed with format %Y%m%d
NOAA_STATION        = "9435380"
NOAA_CHUNK_DAYS     = 365          # days requested per CO-OPS API call

# HMSC monthly archive (used to backfill gaps)
HMSC_BASE_URL       = "http://weather.hmsc.oregonstate.edu/weather/weatherproject/archive/{yyyy}/HMSC_{yyyymm}.dat"
HMSC_UTC_OFFSET_H   = +8  # PST → UTC = +8h (per user)

# lag caps (minutes)
NON_EQ_MAX_LAG_MIN  = 4 * 60    # max lag for bases where is_eq_base() is False
EQ_MAX_LAG_MIN      = 24 * 60   # max lag for bases where is_eq_base() is True

# top-k lagged channels: ONE per base
TOP_LAG_FEATURES    = 12   # number of (base, best-lag) pairs kept after ranking
PER_BASE_CAP        = 1    # NOTE(review): not referenced in the visible part of this file

# spike emphasis on target
WEIGHT_SPIKES       = True
SPIKE_P90           = 0.90
SPIKE_WEIGHT        = 2.0

# time splits
TRAIN_END           = pd.Timestamp("2018-12-31 23:59:59")
VAL_END             = pd.Timestamp("2022-12-31 23:59:59")
MAX_SAMPLES_SPLIT   = 60_000

# ---- Plotting controls ----
# NOTE(review): PRED_START..PRED_END (Jan 2022) is earlier than VAL_END, so if
# the test split starts after VAL_END this window holds no "test" samples — confirm.
PRED_SPLIT          = "test"     # "test" or "val"
PRED_START          = pd.Timestamp("2022-01-01 00:00:00")
PRED_END            = pd.Timestamp("2022-01-31 23:59:59")
PLOT_LAST_DAYS      = 40

# TFT hyperparams
BATCH_SIZE          = 512
EPOCHS              = 6
LR                  = 2e-3
DROPOUT             = 0.2
D_MODEL             = 128
NHEAD               = 4
LSTM_LAYERS         = 1
FF_DIM              = 256

# -------------------
# Load df_seq if needed
# -------------------
# Read the input matrix only when the session does not already hold `df_seq`
# (lets an earlier cell supply it directly).
if DATA_PATH is not None and "df_seq" not in globals():
    if DATA_PATH.lower().endswith(".parquet"):
        df_seq = pd.read_parquet(DATA_PATH)
    else:
        df_seq = pd.read_csv(DATA_PATH)

assert "df_seq" in globals(), "Please define df_seq or set DATA_PATH to your parquet/csv."
assert {"time","var1h"}.issubset(df_seq.columns), "df_seq must include 'time' and 'var1h'."

df_seq["time"] = pd.to_datetime(df_seq["time"])
# Drop missing and duplicate timestamps, as is done for every other source in
# this file: the later alignment step reindexes on "time" and pandas refuses
# to reindex an axis with duplicate labels.
df_seq = (df_seq.dropna(subset=["time"])
          .sort_values("time")
          .drop_duplicates("time")
          .reset_index(drop=True))
# Force every non-time column to float32; unparseable entries become NaN.
for c in df_seq.columns:
    if c != "time":
        df_seq[c] = pd.to_numeric(df_seq[c], errors="coerce").astype("float32")

# -------------------
# Robust HTTP session
# -------------------
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def _make_session():
    """Return a requests.Session with a retrying adapter on http:// and https://.

    Retries up to 4 times with exponential backoff (factor 0.5) on status
    429/500/502/503/504 for HEAD/GET/OPTIONS, without raising on the final
    status, and sends a browser-like User-Agent.
    """
    retry_policy = Retry(
        total=4,
        backoff_factor=0.5,
        status_forcelist=(429,500,502,503,504),
        allowed_methods=["HEAD","GET","OPTIONS"],
        raise_on_status=False,
    )
    sess = requests.Session()
    sess.headers.update({"User-Agent":"Mozilla/5.0 (X11; Linux x86_64) ColabFetcher/1.0", "Accept":"*/*"})
    # One adapter per scheme, both sharing the same retry policy.
    for scheme in ("https://", "http://"):
        sess.mount(scheme, HTTPAdapter(max_retries=retry_policy))
    return sess

# Shared session used by every downloader below.
session = _make_session()

# -------------------
# NDBC 46050 downloader + robust parser  (Hs/DPD/APD/MWD only)
# -------------------
# Keep only years strictly before the current UTC year (presumably because the
# historical stdmet archive only exists for completed years — confirm).
# Timestamp.now(tz="UTC") replaces the deprecated Timestamp.utcnow().
YEARS = [y for y in YEARS if y < pd.Timestamp.now(tz="UTC").year]
# Two URL forms for the same yearly file; fetch_year_text tries them in order.
DIRECT_PAT = "https://www.ndbc.noaa.gov/data/historical/stdmet/46050h{year}.txt.gz"
PHP_PAT    = "https://www.ndbc.noaa.gov/view_text_file.php?filename=46050h{year}.txt.gz&dir=data/historical/stdmet/"

def fetch_year_text(year:int, session:requests.Session)->str|None:
    """Download one year of 46050 stdmet data and return it as text.

    Tries the direct archive URL first, then the PHP viewer URL. The payload
    is gunzipped when possible; a non-gzip payload is returned as plain text
    unless it looks like an HTML page. Any failure moves on to the next URL.

    Returns:
        The decoded text, or None when neither URL yields usable content.
    """
    urls = [pattern.format(year=year) for pattern in (DIRECT_PAT, PHP_PAT)]
    for url in urls:
        try:
            resp = session.get(url, timeout=30)
            if resp.status_code != 200 or not resp.content:
                continue
            content_type = resp.headers.get("Content-Type","").lower()
            payload = resp.content
            try:
                with gzip.GzipFile(fileobj=io.BytesIO(payload)) as gz:
                    return gz.read().decode("utf-8", errors="ignore")
            except OSError:
                # Not gzip: accept as plain text unless it is an HTML page.
                looks_like_html = ("html" in content_type
                                   or payload[:15].lower().startswith(b"<!doctype html"))
                if looks_like_html:
                    continue
                try:
                    return payload.decode("utf-8", errors="ignore")
                except Exception:
                    continue
        except Exception:
            # Best-effort download: fall through to the next candidate URL.
            continue
    return None

def _ndbc_to_float(tok:str) -> float:
    s = str(tok).strip()
    if s == "" or s.upper() == "MM":
        return np.nan
    if re.fullmatch(r"-?9+(\.0+)?", s):  # 99, 9999, 99.0, etc.
        return np.nan
    try:
        return float(s)
    except Exception:
        return np.nan

def parse_ndbc_stdmet_text(txt:str)->pd.DataFrame:
    """Parse NDBC stdmet text into a frame of wave parameters.

    Only rows that start with five integer fields (year month day hour minute)
    are used; comment lines starting with '#' and anything else are skipped.
    The value fields after the timestamp are padded with "MM" or truncated to
    exactly 13, so rows with extra trailing columns can no longer break the
    DataFrame construction.

    Returns:
        Columns ["time", "Hs", "DPD", "APD", "MWD"], sorted by time with
        duplicate timestamps removed ("Hs" is the renamed WVHT column), or an
        empty DataFrame when no data row is found.
    """
    n_value_fields = 13
    row_pat = re.compile(r"^\s*(\d{2}|\d{4})\s+\d{1,2}\s+\d{1,2}\s+\d{1,2}\s+\d{1,2}\s")
    rows = []
    for ln in txt.splitlines():
        if not ln or ln.lstrip().startswith("#"): continue
        if not row_pat.match(ln):
            continue
        parts = re.split(r"\s+", ln.strip())
        y,mo,dy,hh,mi = parts[:5]
        try:
            y = int(y); mo=int(mo); dy=int(dy); hh=int(hh); mi=int(mi)
            # NOTE(review): two-digit years are mapped to 20xx here; confirm
            # against the NDBC format notes if pre-1999 files are ever ingested.
            year = 2000 + y if y < 100 else y
            ts = pd.Timestamp(year=year, month=mo, day=dy, hour=hh, minute=mi, tz="UTC").tz_localize(None)
        except Exception:
            # Impossible calendar dates: skip the row.
            continue
        # Always normalize to exactly n_value_fields tokens (pad short rows,
        # truncate long ones) so each row matches the column list below.
        vals = (parts[5:] + ["MM"]*n_value_fields)[:n_value_fields]
        vals = [_ndbc_to_float(v) for v in vals]
        rows.append([ts] + vals)
    if not rows:
        return pd.DataFrame()
    cols = ["time","WDIR","WSPD","GST","WVHT","DPD","APD","MWD","PRES","ATMP","WTMP","DEWP","VIS","TIDE"]
    df = pd.DataFrame(rows, columns=cols).sort_values("time").drop_duplicates("time")
    df = df.rename(columns={"WVHT":"Hs"})
    keep = ["time","Hs","DPD","APD","MWD"]
    for k in keep:
        if k not in df.columns: df[k] = np.nan
    return df[keep].reset_index(drop=True)

# Download and parse each year; attempt_log records one (year, status) per year.
frames, attempt_log = [], []
for y in YEARS:
    txt = fetch_year_text(y, session)
    if not txt: attempt_log.append((y, "Missing/HTML/404")); continue
    dfy = parse_ndbc_stdmet_text(txt)
    if dfy.empty: attempt_log.append((y, "Parsed 0 rows")); continue
    frames.append(dfy); attempt_log.append((y, f"OK {len(dfy)}"))
# Abort the whole run when no year produced data.
if not frames:
    print("No 46050 data parsed. Year statuses:", attempt_log)
    raise RuntimeError("46050 download failed.")
# One frame for all years, sorted by time with duplicate timestamps removed.
df_46050 = pd.concat(frames, ignore_index=True).sort_values("time").drop_duplicates("time").reset_index(drop=True)

# Clean buoy values pre-align
# Second line of defense after the token-level sentinel handling: anything
# above the threshold is treated as missing.
for col, thresh in [("Hs", 50.0), ("DPD", 50.0), ("APD", 50.0)]:
    if col in df_46050.columns:
        m = df_46050[col] > thresh
        if m.any():
            print(f"Clean[46050 pre]: set {int(m.sum())} '{col}' values >{thresh} to NaN (e.g., 99 sentinel).")
            df_46050.loc[m, col] = np.nan
# Directions above 360 degrees are treated as missing.
if "MWD" in df_46050.columns:
    m = df_46050["MWD"] > 360
    if m.any():
        print(f"Clean[46050 pre]: set {int(m.sum())} 'MWD' values >360 to NaN.")
        df_46050.loc[m, "MWD"] = np.nan

# Coverage report for both inputs and the last ten fetch statuses.
print("df_seq coverage:", df_seq["time"].min(), "→", df_seq["time"].max(), f"(n={len(df_seq):,})")
print("46050 raw coverage:", df_46050["time"].min(), "→", df_46050["time"].max(), f"(n={len(df_46050):,})")
print("46050 yearly fetch status (tail):", attempt_log[-10:])

# -------------------
# NOAA CO-OPS 9435380 (hourly wind) — chunked
# -------------------
def parse_noaa_wind_csv(txt:str)->pd.DataFrame:
    """Parse a CO-OPS wind CSV response into speed, direction and u/v columns.

    The timestamp column is located by name (falling back to the first
    column) and parsed as UTC, then made timezone-naive. Speed and direction
    columns are located by exact name first ("s"/"speed", "d"/"dir"/
    "direction"), then by substring for the multi-letter keys only. The
    timestamp column is never a candidate: a one-letter substring test would
    otherwise select "Date Time" as the direction column because it contains
    the letter "d".

    Returns:
        Columns time, wind9435380_speed, wind9435380_dir, wind9435380_u,
        wind9435380_v, sorted by time, without NaT or duplicate timestamps.
        Directions outside [0, 360] and speeds outside [0, 100] are NaN.
        u/v use the meteorological "from" convention.
    """
    df = pd.read_csv(io.StringIO(txt))
    time_cols = [c for c in df.columns if str(c).strip().lower() in ("t","time","date time","date_time","date")]
    time_col = time_cols[0] if time_cols else df.columns[0]
    df["time"] = pd.to_datetime(df[time_col], errors="coerce", utc=True).dt.tz_localize(None)

    def _find(colnames, keys):
        keys = [k.lower() for k in keys]
        # Normalized names of every column except the timestamp columns.
        candidates = [(c, str(c).strip().lower()) for c in colnames
                      if c != time_col and c != "time"]
        # Pass 1: exact name match (covers both short and long header styles).
        for c, s in candidates:
            if s in keys:
                return c
        # Pass 2: substring match, multi-letter keys only, so a single letter
        # cannot match an unrelated column.
        for c, s in candidates:
            if any(len(k) > 1 and k in s for k in keys):
                return c
        return None

    sp_col = _find(df.columns, ["s","speed"])   # m/s
    dir_col = _find(df.columns, ["d","dir","direction"])

    out = pd.DataFrame({"time": df["time"]})
    out["wind9435380_speed"] = pd.to_numeric(df[sp_col], errors="coerce") if sp_col is not None else np.nan
    out["wind9435380_dir"]   = pd.to_numeric(df[dir_col], errors="coerce") if dir_col is not None else np.nan
    # Sanity
    out.loc[(out["wind9435380_dir"] < 0) | (out["wind9435380_dir"] > 360), "wind9435380_dir"] = np.nan
    out.loc[(out["wind9435380_speed"] < 0) | (out["wind9435380_speed"] > 100), "wind9435380_speed"] = np.nan

    # Components of the wind vector from speed and from-direction.
    th = np.deg2rad(out["wind9435380_dir"].to_numpy(dtype=float))
    s  = out["wind9435380_speed"].to_numpy(dtype=float)
    u = -s * np.sin(th); v = -s * np.cos(th)
    out["wind9435380_u"] = u.astype(np.float32)
    out["wind9435380_v"] = v.astype(np.float32)

    out = out.sort_values("time").dropna(subset=["time"]).drop_duplicates("time")
    return out.reset_index(drop=True)

def _datestr(dt) -> str:
    return pd.Timestamp(dt).strftime("%Y%m%d")

def fetch_noaa_wind_dataframe(begin_date:str, end_date:str, station:str, days_per_chunk:int=365) -> pd.DataFrame:
    """Download hourly CO-OPS wind data in date chunks and return one frame.

    Args:
        begin_date: First day to request, as YYYYMMDD.
        end_date: Last day to request (inclusive), as YYYYMMDD.
        station: CO-OPS station id.
        days_per_chunk: Number of days per request; must be >= 1.

    Returns:
        The concatenated chunks, sorted by time with duplicate timestamps
        removed. An empty frame with the expected columns is returned when no
        chunk produced data. Failed chunks are reported with print and skipped.

    Raises:
        ValueError: If days_per_chunk is smaller than 1. The cursor would
            never advance, so the loop would not terminate.
    """
    if days_per_chunk < 1:
        raise ValueError(f"days_per_chunk must be >= 1, got {days_per_chunk!r}")
    start = pd.to_datetime(begin_date, format="%Y%m%d")
    end   = pd.to_datetime(end_date,   format="%Y%m%d")
    frames = []
    cur = start
    while cur <= end:
        # Inclusive chunk of at most days_per_chunk days, clipped to `end`.
        chunk_end = min(end, cur + pd.Timedelta(days=days_per_chunk-1))
        url = (
            "https://api.tidesandcurrents.noaa.gov/api/prod/datagetter"
            f"?begin_date={_datestr(cur)}&end_date={_datestr(chunk_end)}"
            f"&station={station}&product=wind&time_zone=gmt&interval=h&units=metric"
            "&application=DataAPI_Sample&format=csv"
        )
        try:
            r = session.get(url, timeout=60)
            if r.status_code == 200 and r.content:
                dfc = parse_noaa_wind_csv(r.content.decode("utf-8", errors="ignore"))
                if not dfc.empty: frames.append(dfc)
            else:
                print(f"NOAA wind chunk {cur:%Y-%m-%d}→{chunk_end:%Y-%m-%d} failed:", r.status_code)
        except Exception as e:
            # Best effort: report the chunk and keep going.
            print(f"NOAA wind chunk error {cur:%Y-%m-%d}→{chunk_end:%Y-%m-%d}:", e)
        cur = chunk_end + pd.Timedelta(days=1)

    if not frames:
        return pd.DataFrame(columns=["time","wind9435380_speed","wind9435380_dir","wind9435380_u","wind9435380_v"])
    df = pd.concat(frames, ignore_index=True).sort_values("time").drop_duplicates("time")
    return df.reset_index(drop=True)

# Baseline hourly wind record for the configured station and window; the
# coverage print tolerates an empty result.
df_noaa = fetch_noaa_wind_dataframe(NOAA_BEGIN_DATE, NOAA_END_DATE, NOAA_STATION, days_per_chunk=NOAA_CHUNK_DAYS)
print("9435380 NOAA coverage:",
      (df_noaa["time"].min() if len(df_noaa) else None), "→",
      (df_noaa["time"].max() if len(df_noaa) else None), f"(n={len(df_noaa):,})")

# -------------------
# HMSC monthly .dat fetch + parse + PST->UTC + mph->m/s + u/v + hourly resample
# -------------------
def fetch_hmsc_month_text(year:int, month:int) -> str|None:
    """Download one HMSC monthly .dat file and return its text.

    Returns None on a non-200 status, an empty body, or any request error.
    """
    url = HMSC_BASE_URL.format(yyyy=year, yyyymm=f"{year}{month:02d}")
    try:
        resp = session.get(url, timeout=60)
        if resp.status_code != 200 or not resp.content:
            return None
        return resp.content.decode("utf-8", errors="ignore")
    except Exception:
        # Best effort: a missing month is simply skipped by the caller.
        return None

def parse_hmsc_dat(txt:str) -> pd.DataFrame:
    """Parse one HMSC monthly .dat file into UTC wind speed, direction and u/v.

    Everything before the header row (the first line starting with TIMESTAMP,
    optionally quoted) is skipped. Rows whose first field does not parse as a
    datetime (e.g. units/label rows) are dropped. AWS is converted from mph to
    m/s and AWD is taken as degrees. Timestamps are shifted by
    HMSC_UTC_OFFSET_H hours to UTC.

    Returns:
        Columns time, hmsc_speed_ms, hmsc_dir_deg, hmsc_u, hmsc_v, sorted by
        time without NaT or duplicate timestamps; an empty DataFrame when the
        header row or the AWS/AWD columns cannot be found. Directions outside
        [0, 360] and speeds outside [0, 100] m/s are NaN, the same sanity
        limits the NOAA parser applies.
    """
    # Find the header line that starts with TIMESTAMP
    lines = txt.splitlines()
    start_idx = next(
        (i for i, ln in enumerate(lines) if ln.strip().strip('"').startswith("TIMESTAMP")),
        None,
    )
    if start_idx is None:
        return pd.DataFrame()

    # Reconstruct CSV content from header onward
    content = "\n".join(lines[start_idx:])
    df = pd.read_csv(io.StringIO(content))

    # Drop the immediately following "units/labels" rows where TIMESTAMP is TS or empty
    # Keep rows that look like actual datetimes.
    def _try_time(x):
        try:
            return pd.to_datetime(x)
        except Exception:
            return pd.NaT
    tt = df.iloc[:,0].apply(_try_time)
    df = df.loc[tt.notna()].copy()
    df.rename(columns={df.columns[0]:"TIMESTAMP"}, inplace=True)

    # Keep AWS (avg wind speed, mph) and AWD (avg wind direction, deg).
    # Exact names are preferred; otherwise match ignoring case and padding.
    cols_lower = {str(c).strip().lower(): c for c in df.columns}
    AWS_col = "AWS" if "AWS" in df.columns else cols_lower.get("aws")
    AWD_col = "AWD" if "AWD" in df.columns else cols_lower.get("awd")

    if AWS_col is None or AWD_col is None:
        return pd.DataFrame()

    out = pd.DataFrame()
    # Local time (PST), convert to UTC by +8h (per user)
    out["time"] = pd.to_datetime(df["TIMESTAMP"], errors="coerce") + pd.Timedelta(hours=HMSC_UTC_OFFSET_H)

    # Convert mph -> m/s, then apply the same sanity limits as the NOAA parser.
    mph = pd.to_numeric(df[AWS_col], errors="coerce")
    spd_ms = mph * 0.44704
    spd_ms = spd_ms.where((spd_ms >= 0) & (spd_ms <= 100))

    # Direction deg (meteorological FROM). `where` returns a new float Series,
    # so NaN never has to be written into a possibly integer-typed column.
    dir_deg = pd.to_numeric(df[AWD_col], errors="coerce")
    dir_deg = dir_deg.where((dir_deg >= 0) & (dir_deg <= 360))

    out["hmsc_speed_ms"] = spd_ms.astype("float32")
    out["hmsc_dir_deg"]  = dir_deg.astype("float32")

    # u/v components (meteorological FROM)
    th = np.deg2rad(out["hmsc_dir_deg"].to_numpy(dtype=float))
    s  = out["hmsc_speed_ms"].to_numpy(dtype=float)
    u = -s * np.sin(th); v = -s * np.cos(th)
    out["hmsc_u"] = u.astype(np.float32)
    out["hmsc_v"] = v.astype(np.float32)

    out = out.sort_values("time").dropna(subset=["time"]).drop_duplicates("time")
    return out.reset_index(drop=True)

def month_iter(start:pd.Timestamp, end:pd.Timestamp):
    """Yield (year, month) from start's month while the month's first day <= end."""
    # Count months on a single 0-based axis so the year rollover is a divmod.
    month_index = start.year * 12 + (start.month - 1)
    while True:
        year, month0 = divmod(month_index, 12)
        if pd.Timestamp(year=year, month=month0 + 1, day=1) > end:
            return
        yield (year, month0 + 1)
        month_index += 1

def fetch_hmsc_range(start:pd.Timestamp, end:pd.Timestamp) -> pd.DataFrame:
    """Fetch and parse every HMSC monthly file needed to cover [start, end] (UTC).

    Monthly files are organised by local time, while parse_hmsc_dat returns
    timestamps shifted by HMSC_UTC_OFFSET_H hours. The month range is
    therefore widened by that offset: with a +8 h offset, the first UTC hours
    of a month are stored in the previous month's file.

    Returns:
        Rows within [start - 1 day, end + 1 day], sorted by time with
        duplicate timestamps removed, or an empty frame with the expected
        columns when nothing could be fetched.
    """
    offset = pd.Timedelta(hours=HMSC_UTC_OFFSET_H)
    one_day = pd.Timedelta(days=1)
    # Cover both the UTC span and its local-time image, whatever the offset sign.
    first_day = min(start, start - offset).normalize()
    last_day  = max(end, end - offset).normalize()

    frames=[]
    for y,m in month_iter(first_day, last_day):
        txt = fetch_hmsc_month_text(y, m)
        if not txt: continue
        dfm = parse_hmsc_dat(txt)
        if not dfm.empty:
            # Keep only within [start-1d, end+1d] to be safe
            dfm = dfm[(dfm["time"] >= start - one_day) &
                      (dfm["time"] <= end   + one_day)]
            if len(dfm): frames.append(dfm)
    if not frames:
        return pd.DataFrame(columns=["time","hmsc_speed_ms","hmsc_dir_deg","hmsc_u","hmsc_v"])
    df = pd.concat(frames, ignore_index=True).sort_values("time").drop_duplicates("time")
    return df.reset_index(drop=True)

# -------------------
# Build hourly NOAA baseline and fill its gaps from HMSC
# -------------------
def resample_hmsc_hourly(df_hmsc:pd.DataFrame) -> pd.DataFrame:
    """Resample HMSC u/v to hourly vector means and rebuild speed/direction.

    u and v are averaged per hour (NaNs skipped by the mean); direction is the
    from-direction of the mean vector in [0, 360) and speed is the magnitude
    of the mean vector. Hours without samples come out as NaN. An empty input
    is returned unchanged.

    Uses the lowercase "h" offset alias; the uppercase "H" spelling is
    deprecated in recent pandas.

    Returns:
        Columns time, hmsc_u, hmsc_v, hmsc_speed_ms, hmsc_dir_deg (float32
        values), one row per hour between the first and last sample.
    """
    if df_hmsc.empty:
        return df_hmsc
    g = df_hmsc.set_index("time").sort_index()

    # Vector mean per hour
    hourly_u = g["hmsc_u"].resample("h").mean()
    hourly_v = g["hmsc_v"].resample("h").mean()
    u_arr = hourly_u.to_numpy()
    v_arr = hourly_v.to_numpy()
    # Direction-from in deg
    th = np.degrees(np.arctan2(-u_arr, -v_arr)) % 360.0
    # Vector-mean speed
    spd = np.sqrt(u_arr**2 + v_arr**2)

    out = pd.DataFrame({
        "time": hourly_u.index,
        "hmsc_u": hourly_u.values.astype("float32"),
        "hmsc_v": hourly_v.values.astype("float32"),
        "hmsc_speed_ms": spd.astype("float32"),
        "hmsc_dir_deg": th.astype("float32"),
    }).dropna(subset=["time"])
    return out.reset_index(drop=True)

# NOAA hourly baseline
# Work on a copy so the raw NOAA frame stays untouched.
df_noaa_hourly = df_noaa.copy()
# Ensure exact hourly index (some responses are already hourly)
# NOTE(review): this only sorts and de-duplicates; snapping onto the hourly
# grid happens in the reindex further down.
if not df_noaa_hourly.empty:
    df_noaa_hourly["time"] = pd.to_datetime(df_noaa_hourly["time"])
    df_noaa_hourly = df_noaa_hourly.sort_values("time").drop_duplicates("time")

# Define the time span we need for filling: use df_noaa span; if empty, use df_seq/df_46050 overlap
if len(df_noaa_hourly):
    fill_start, fill_end = df_noaa_hourly["time"].min(), df_noaa_hourly["time"].max()
else:
    fill_start = max(df_seq["time"].min(), df_46050["time"].min())
    fill_end   = min(df_seq["time"].max(), df_46050["time"].max())

# Fetch HMSC for that span and build hourly series
df_hmsc_raw = fetch_hmsc_range(fill_start, fill_end)
print("HMSC raw coverage:",
      (df_hmsc_raw["time"].min() if len(df_hmsc_raw) else None), "→",
      (df_hmsc_raw["time"].max() if len(df_hmsc_raw) else None), f"(n={len(df_hmsc_raw):,})")
# Hourly vector means of the raw HMSC samples.
df_hmsc_hour = resample_hmsc_hourly(df_hmsc_raw)
print("HMSC hourly coverage:",
      (df_hmsc_hour["time"].min() if len(df_hmsc_hour) else None), "→",
      (df_hmsc_hour["time"].max() if len(df_hmsc_hour) else None), f"(n={len(df_hmsc_hour):,})")

# Reindex NOAA to a perfect hourly grid over [fill_start..fill_end]
# Complete hourly grid spanning the fill window; hours without a NOAA record
# become all-NaN rows after the reindex. Uses the lowercase "h" offset alias
# (the uppercase "H" spelling is deprecated in recent pandas).
hourly_grid = pd.date_range(fill_start.floor("h"), fill_end.ceil("h"), freq="h")
noaa = (df_noaa_hourly.set_index("time")
        .reindex(hourly_grid)
        .rename_axis("time")
        .reset_index())
# Short working names; the wind9435380_* names are restored in the final frame.
noaa.rename(columns={
    "wind9435380_speed":"speed",
    "wind9435380_dir":"dir",
    "wind9435380_u":"u",
    "wind9435380_v":"v"
}, inplace=True)

# Identify gaps
# A row is a gap when any of its four wind columns is missing.
gap_mask = noaa["speed"].isna() | noaa["dir"].isna() | noaa["u"].isna() | noaa["v"].isna()
n_gaps = int(gap_mask.sum())
print(f"9435380 NOAA hourly NaN rows before fill: {n_gaps:,}")

# Build HMSC hourly aligned to same grid
# (row i of `hmsc` and row i of `noaa` then refer to the same hour)
if len(df_hmsc_hour):
    hmsc = (df_hmsc_hour.set_index("time")
            .reindex(hourly_grid)
            .rename_axis("time")
            .reset_index())
else:
    # No HMSC data at all: an all-NaN frame keeps the fill logic uniform.
    hmsc = pd.DataFrame({"time": hourly_grid})
    hmsc["hmsc_speed_ms"] = np.nan
    hmsc["hmsc_dir_deg"]  = np.nan
    hmsc["hmsc_u"] = np.nan
    hmsc["hmsc_v"] = np.nan

# Fill 9435380 gaps with HMSC (only where NOAA is NaN and HMSC has a value)
# All four HMSC columns must be present; a gap row is replaced as a whole.
fillable = gap_mask & hmsc["hmsc_speed_ms"].notna() & hmsc["hmsc_dir_deg"].notna() & hmsc["hmsc_u"].notna() & hmsc["hmsc_v"].notna()
filled_rows = int(fillable.sum())
noaa.loc[fillable, "speed"] = hmsc.loc[fillable, "hmsc_speed_ms"].values
noaa.loc[fillable, "dir"]   = hmsc.loc[fillable, "hmsc_dir_deg"].values
noaa.loc[fillable, "u"]     = hmsc.loc[fillable, "hmsc_u"].values
noaa.loc[fillable, "v"]     = hmsc.loc[fillable, "hmsc_v"].values
print(f"Filled from HMSC: {filled_rows:,} hourly rows.")
print(f"Remaining NOAA hourly NaN rows after fill: {int(noaa[['speed','dir','u','v']].isna().any(axis=1).sum()):,}")

# Final merged hourly wind series (named as 9435380 to keep downstream code simple)
df_wind_filled = pd.DataFrame({
    "time": noaa["time"],
    "wind9435380_speed": noaa["speed"].astype("float32"),
    "wind9435380_dir":   noaa["dir"].astype("float32"),
    "wind9435380_u":     noaa["u"].astype("float32"),
    "wind9435380_v":     noaa["v"].astype("float32"),
}).dropna(subset=["time"]).reset_index(drop=True)

print("Merged hourly wind (NOAA+HMSC fill) coverage:",
      (df_wind_filled["time"].min() if len(df_wind_filled) else None), "→",
      (df_wind_filled["time"].max() if len(df_wind_filled) else None), f"(n={len(df_wind_filled):,})")

# -------------------
# Align to 6-min grid (+ optional unified gap mask)
# -------------------
def interp_timegrid(t_src:pd.Series, x_src:np.ndarray, t_grid:pd.DatetimeIndex)->np.ndarray:
    """Linearly interpolate (t_src, x_src) in time onto t_grid.

    Source samples with a missing timestamp are ignored, and for duplicate
    timestamps the first occurrence wins (a duplicated index cannot be
    reindexed). Interpolation is weighted by time and only fills between
    valid source samples; grid points outside the source range, or when the
    source is empty, come out as NaN.

    Args:
        t_src: Source timestamps (any order).
        x_src: Source values, same length as t_src.
        t_grid: Target timestamps.

    Returns:
        Array of interpolated values, one per entry of t_grid.
    """
    src_index = pd.DatetimeIndex(pd.to_datetime(t_src))
    s = pd.Series(np.asarray(x_src), index=src_index)
    # Drop NaT and repeated timestamps, then sort so the union below is
    # monotonic even in the corner cases where Index.union does not sort.
    usable = src_index.notna() & ~src_index.duplicated(keep="first")
    s = s[usable].sort_index()

    grid = pd.DatetimeIndex(pd.to_datetime(t_grid))
    # Index.union gives the sorted, unique set of source and target stamps
    # without building Python sets of Timestamps.
    u = s.reindex(s.index.union(grid))
    u = u.interpolate(method="time", limit_area="inside")
    return u.reindex(grid).to_numpy()

# Define overlap grid across df_seq, 46050, and wind_filled
# Start of the overlap: the latest first-timestamp among the sources
# (the wind term falls back to df_seq when no wind data exists).
t0 = max(
    df_seq["time"].min(),
    df_46050["time"].min(),
    (df_wind_filled["time"].min() if len(df_wind_filled) else df_seq["time"].min())
)
# End of the overlap: the earliest last-timestamp among the sources.
t1 = min(
    df_seq["time"].max(),
    df_46050["time"].max(),
    (df_wind_filled["time"].max() if len(df_wind_filled) else df_seq["time"].max())
)
# Regular grid with a DT_MINUTES step between t0 and t1.
t_grid = pd.date_range(t0, t1, freq=f"{DT_MINUTES}min")
print("Common grid (pre-mask):", t0, "→", t1, f"(n={len(t_grid):,})")

# Build matrix for df_seq bases
# seq_mat maps column name -> values interpolated onto t_grid.
seq_cols = [c for c in df_seq.columns if c != "time"]
seq_mat = {c: interp_timegrid(df_seq["time"], df_seq[c].values, t_grid) for c in seq_cols}

# Add ONLY wave params from 46050 (Hs/DPD/APD/MWD)
# NOTE(review): MWD is an angle but is interpolated linearly like the other
# columns, so values across the 0/360 wrap are averaged through 180 — confirm
# this is acceptable.
for c in ["Hs","DPD","APD","MWD"]:
    seq_mat[c] = interp_timegrid(df_46050["time"], df_46050[c].values, t_grid)

# Add wind (no gust)
# NOTE(review): the same wrap-around caveat applies to wind9435380_dir.
if len(df_wind_filled):
    for c in ["wind9435380_speed","wind9435380_dir","wind9435380_u","wind9435380_v"]:
        seq_mat[c] = interp_timegrid(df_wind_filled["time"], df_wind_filled[c].values, t_grid)

def gap_intervals(t:pd.Series, x:np.ndarray, min_gap_hours:float)->List[Tuple[pd.Timestamp,pd.Timestamp]]:
    """List the time intervals in which a series has no usable data.

    Two kinds of gap are reported:
      * time jumps: consecutive timestamps more than min_gap_hours apart,
        reported as (timestamp before the jump, timestamp after the jump);
      * NaN runs: consecutive NaN values whose first and last timestamps are
        at least min_gap_hours apart, reported as (first NaN, last NaN).

    Args:
        t: Timestamps, same length as x (positional; the index is ignored).
        x: Values aligned with t.
        min_gap_hours: Gap threshold in hours.

    Returns:
        Time-jump intervals first, then NaN-run intervals; not merged.
    """
    # Positional DatetimeIndex: indexing below must not depend on t's labels.
    tt = pd.DatetimeIndex(pd.to_datetime(t))
    dt = np.diff(tt.values).astype("timedelta64[s]").astype(np.int64)
    jumps = np.where(dt > min_gap_hours*3600)[0]
    iv = [(tt[j], tt[j+1]) for j in jumps]

    isn = pd.Series(x).isna().to_numpy()
    if isn.any():
        # Cast to int before differencing: np.diff on a boolean array is an
        # XOR, which can never produce the -1 that marks the end of a run.
        flags = isn.astype(np.int64)
        starts = np.where(np.diff(np.concatenate(([0], flags))) == 1)[0]
        ends   = np.where(np.diff(np.concatenate((flags, [0]))) == -1)[0]
        for a,b in zip(starts, ends):
            if (tt[b]-tt[a]) >= pd.Timedelta(hours=min_gap_hours):
                iv.append((tt[a], tt[b]))
    return iv

def merge_intervals(intervals):
    """Merge overlapping or touching (start, end) intervals.

    Intervals are ordered by start; an interval whose start is <= the end of
    the running interval extends it. Returns a list of (start, end) tuples
    passed through pd.to_datetime, or [] for empty input.
    """
    if not intervals:
        return []
    ordered = sorted(intervals, key=lambda pair: pair[0])
    merged = [list(ordered[0])]
    for lo, hi in ordered[1:]:
        current = merged[-1]
        if lo <= current[1]:
            current[1] = max(current[1], hi)
        else:
            merged.append([lo, hi])
    return [(pd.to_datetime(lo), pd.to_datetime(hi)) for lo, hi in merged]

# Optional unified mask: a grid point is dropped when ANY source has a gap of
# at least MAX_GAP_HOURS covering it. Note that `gap_mask` is rebound here: it
# now has one entry per t_grid point (it was the hourly NOAA gap flag above).
if USE_UNIFIED_GAP_MASK:
    all_intervals = []
    for c in seq_cols:
        all_intervals += gap_intervals(df_seq["time"], df_seq[c].values, MAX_GAP_HOURS)
    for c in ["Hs","DPD","APD","MWD"]:
        all_intervals += gap_intervals(df_46050["time"], df_46050[c].values, MAX_GAP_HOURS)
    if len(df_wind_filled):
        for c in ["wind9435380_speed","wind9435380_dir","wind9435380_u","wind9435380_v"]:
            all_intervals += gap_intervals(df_wind_filled["time"], df_wind_filled[c].values, MAX_GAP_HOURS)
    merged_gaps = merge_intervals(all_intervals)
    gap_mask = np.zeros(len(t_grid), dtype=bool)
    # Mark grid points inside any merged gap (both ends inclusive).
    for s,e in merged_gaps:
        gap_mask |= (t_grid>=s) & (t_grid<=e)
else:
    # Mask disabled: keep every grid point.
    gap_mask = np.zeros(len(t_grid), dtype=bool)

# Assemble the aligned table and drop the masked grid points.
df_all = pd.DataFrame({"time": t_grid})
for c,v in seq_mat.items():
    df_all[c] = v
df_all = df_all.loc[~gap_mask].reset_index(drop=True)

# Post-align cleaning + interpolation (guarantee no >50 in plots)
# Re-apply the >50 threshold on the aligned table, then re-interpolate in
# time (interior gaps only) so the removed points are filled again.
for col, thresh in [("Hs", 50.0), ("DPD", 50.0), ("APD", 50.0)]:
    if col in df_all.columns:
        m = df_all[col] > thresh
        if m.any():
            print(f"Clean[post-align]: set {int(m.sum())} '{col}' values >{thresh} to NaN.")
            df_all.loc[m, col] = np.nan
        df_all[col] = pd.Series(df_all[col].values, index=df_all["time"]).interpolate(
            method="time", limit_area="inside"
        ).values

# Coverage report; with the mask enabled also report how many grid points it removed.
if USE_UNIFIED_GAP_MASK:
    full_grid_len = len(pd.date_range(t0, t1, freq=f"{DT_MINUTES}min"))
    print("Aligned coverage (after unified mask):",
          df_all["time"].min(), "→", df_all["time"].max(), f"(n={len(df_all):,})")
    print("Masked out grid points due to gaps:",
          f"{full_grid_len - len(df_all):,}")
else:
    print("Aligned coverage (mask disabled):",
          df_all["time"].min(), "→", df_all["time"].max(), f"(n={len(df_all):,})")

# -------------------
# Feature engineering (wl_gradient only)
# -------------------
# wl_gradient is the change of wl_low per hour between consecutive rows. The
# real time difference between rows is used, not a fixed step count, so the
# value stays a per-hour rate when rows are not DT_MINUTES apart (e.g. across
# grid points removed by the unified gap mask) or when DT_MINUTES does not
# divide 60. On a regular grid this equals diff * (60 / DT_MINUTES).
if "wl_low" in df_all.columns:
    wl = pd.Series(df_all["wl_low"].values, index=df_all["time"])
    d  = wl.diff()
    kph = int(60/DT_MINUTES)  # steps per hour on the nominal grid (kept for reference)
    dt_hours = df_all["time"].diff().dt.total_seconds().to_numpy() / 3600.0
    df_all["wl_gradient"] = (d.to_numpy() / dt_hours).astype("float32")
else:
    df_all["wl_gradient"] = np.nan

# -------------------
# Lag candidates
# -------------------
def is_eq_base(name: str) -> bool:
    """True when the lower-cased name starts with "eq" or contains "eq_"."""
    lowered = name.lower()
    if lowered.startswith("eq"):
        return True
    return "eq_" in lowered

def build_lag_candidates(df: pd.DataFrame,
                         bases: List[str],
                         dt_minutes: int) -> Tuple[pd.DataFrame, List[str], Dict[str,str], Dict[str,int]]:
    """Build lagged copies of every base column on a 30-minute lag grid.

    For each base that has at least one finite value, lags 0, 30, 60, ...
    minutes are generated up to EQ_MAX_LAG_MIN (bases for which is_eq_base is
    true) or NON_EQ_MAX_LAG_MIN (all others). A lag is a row shift of
    round(lag / dt_minutes) steps, so it assumes rows are dt_minutes apart.

    Returns:
        (lag_df, lag_cols, base_map, lag_minutes_map): the lagged columns
        named "<base>__lag_<L>min", the column names in creation order, and
        maps from column name to base and to lag in minutes.
    """
    columns = {}
    names = []
    name_to_base = {}
    name_to_lag = {}
    for base in bases:
        values = df[base].astype("float32").to_numpy()
        if not np.isfinite(values).any():
            continue
        lag_cap = EQ_MAX_LAG_MIN if is_eq_base(base) else NON_EQ_MAX_LAG_MIN
        shiftable = pd.Series(values)
        for lag_min in range(0, lag_cap + 1, 30):
            n_steps = int(round(lag_min / dt_minutes))
            col = f"{base}__lag_{lag_min}min"
            name_to_lag[col] = lag_min
            name_to_base[col] = base
            # Lag 0 reuses the base array as is; other lags shift rows down.
            if n_steps == 0:
                columns[col] = values
            else:
                columns[col] = shiftable.shift(n_steps).to_numpy(dtype="float32")
            names.append(col)
    lag_df = pd.DataFrame(columns).reindex(range(len(df))).reset_index(drop=True)
    return lag_df, names, name_to_base, name_to_lag

# Candidate bases: all numeric columns, excluding a few
# Columns never offered as lag bases: the timestamp, the target and its
# variants, and the listed raw wind/gust columns.
exclude_cols = set([
    "time","var1h","var1h_log1p","wsp","wdir","gust","atmvar1_boost"  # gust excluded
])
# Every remaining non-object column is a candidate base.
cand_bases = [c for c in df_all.columns if (c not in exclude_cols and df_all[c].dtype != object)]
lag_df, lag_cols, base_map, lag_minutes_map = build_lag_candidates(df_all, cand_bases, DT_MINUTES)

# -------------------
# XGB ranking (1 best lag per base)
# -------------------
from xgboost import XGBRegressor
# Target for ranking: var1h shifted PRIMARY_HORIZON minutes into the future
# (expressed in grid steps, at least one).
h_steps = max(1, int(round(PRIMARY_HORIZON/DT_MINUTES)))
tgt = df_all["var1h"].shift(-h_steps).reset_index(drop=True).astype("float32")
# dropna() removes every row with a NaN in ANY lag column or in the target.
rank_data = pd.concat([lag_df, tgt.rename("target")], axis=1).dropna().reset_index(drop=True)
if rank_data.empty: raise RuntimeError("No data available for XGB ranking (after lagging and dropna).")

X_rank = rank_data[lag_cols].to_numpy(dtype=np.float32)
y_rank = rank_data["target"].to_numpy(dtype=np.float32)

xgb = XGBRegressor(
    n_estimators=800, learning_rate=0.03, max_depth=5,
    subsample=0.8, colsample_bytree=0.8, reg_lambda=1.0,
    objective="reg:squarederror", n_jobs=-1, random_state=42
)
# Fit on the first 80% of the rows; the remaining 20% is not used here.
# NOTE(review): this 80% cut is unrelated to TRAIN_END/VAL_END, so rows after
# TRAIN_END can influence the feature selection — confirm this is intended.
split = max(1, int(0.8*len(X_rank)))
xgb.fit(X_rank[:split], y_rank[:split], verbose=False)
importances = dict(zip(lag_cols, xgb.feature_importances_))

# For each base keep the lag column with the highest importance.
best_lag_per_base = {}
for nm, imp in importances.items():
    b = base_map[nm]
    prev = best_lag_per_base.get(b)
    if (prev is None) or (imp > prev[1]): best_lag_per_base[b] = (nm, float(imp))
# Rank the bases by that importance and keep the top TOP_LAG_FEATURES.
best_list = [(b, nm, imp) for b,(nm,imp) in best_lag_per_base.items()]
best_list.sort(key=lambda t: t[2], reverse=True)
best_list = best_list[:TOP_LAG_FEATURES]

selected_lag_cols = [nm for _, nm, _ in best_list]
selected_bases    = [b  for b, _, _ in best_list]
print("\nSelected ONE best lag per base:")
for b, nm, imp in best_list:
    print(f"  {nm:<30s}  (base={b}, importance={imp:.6f}, lag={lag_minutes_map[nm]} min)")

# Modelling table: time, unshifted target, and the selected lagged channels.
df_lag_all = pd.concat(
    [df_all[["time","var1h"]].reset_index(drop=True),
     lag_df[selected_lag_cols].reset_index(drop=True)],
    axis=1
)

# -------------------
# Mini-TFT (seq-to-one)
# -------------------
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Enable cuDNN autotuning and "high" float32 matmul precision when CUDA is
# available (the latter only when this torch version provides the setter).
use_gpu = torch.cuda.is_available()
if use_gpu:
    torch.backends.cudnn.benchmark = True
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")

# Placeholder for an autocast context: always returns a no-op context manager.
def _amp_ctx(): return nullcontext()

class GLU(nn.Module):
    """Gated linear unit: sigmoid(g(x)) * v(x), with two d -> d linear maps."""

    def __init__(self, d):
        super().__init__()
        self.v = nn.Linear(d, d)  # value branch
        self.g = nn.Linear(d, d)  # gate branch

    def forward(self, x):
        gate = torch.sigmoid(self.g(x))
        value = self.v(x)
        return gate * value

class MiniTFT(nn.Module):
    """Small TFT-style sequence-to-one regressor.

    Pipeline in forward(): per-feature linear embeddings -> softmax feature
    weights (variable selection) -> weighted sum -> dropout -> LSTM ->
    attention of the last step over all steps -> feed-forward head with a GLU
    -> one scalar per sample. The latest feature weights are kept (detached)
    in `last_alpha`.
    """

    def __init__(self, num_features:int, d_model:int=128, nhead:int=4, lstm_layers:int=1,
                 ff_dim:int=256, dropout:float=0.2):
        super().__init__()
        self.F = num_features
        self.d = d_model
        # One scalar -> d_model embedding per input feature.
        self.feat_emb = nn.ModuleList([nn.Linear(1, d_model) for _ in range(num_features)])
        # Logits of the per-step feature weights, computed from the raw features.
        self.vsn_w = nn.Linear(num_features, num_features)
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=d_model,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=0.0,
            bidirectional=False,
        )
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True
        )
        self.proj = nn.Sequential(
            nn.Linear(2*d_model, ff_dim), nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model), GLU(d_model),
            nn.Linear(d_model, 1)
        )
        self.dropout = nn.Dropout(dropout)
        self.last_alpha = None

    def forward(self, x):  # (B,T,F)
        n_batch, n_steps, n_feat = x.shape  # requires a 3-D input
        # Embed each feature channel separately: F tensors of shape (B,T,d).
        embedded = [emb(x[:, :, idx].unsqueeze(-1)) for idx, emb in enumerate(self.feat_emb)]
        stacked = torch.stack(embedded, dim=2)                     # (B,T,F,d)
        weights = torch.softmax(self.vsn_w(x), dim=2).unsqueeze(-1)  # (B,T,F,1)
        self.last_alpha = weights.detach()
        mixed = (weights * stacked).sum(dim=2)                     # (B,T,d)
        mixed = self.dropout(mixed)
        hidden, _ = self.lstm(mixed)                               # (B,T,d)
        final_step = hidden[:, -1:, :]                             # (B,1,d)
        # The final step queries the whole hidden sequence.
        context, _ = self.attn(query=final_step, key=hidden, value=hidden, need_weights=False)
        features = torch.cat([final_step, context], dim=-1).squeeze(1)  # (B,2d)
        return self.proj(features).squeeze(-1)

def mean_vsn_weights(model:MiniTFT, X:np.ndarray, batch_size:int=4096)->np.ndarray:
    """Average the model's feature-selection weights over all samples and steps.

    Runs X through the model in batches (eval mode, no gradients) and reads
    `model.last_alpha` after each forward pass. The average is taken over
    samples, not over batches, so a shorter final batch is not over-weighted.
    The model is left in eval mode.

    Args:
        model: Model exposing `last_alpha` of shape (B, T, F, 1) after forward.
        X: Windows to evaluate, first axis = samples.
        batch_size: Maximum batch size.

    Returns:
        float32 array with one mean weight per feature, or an empty array when
        X is empty.
    """
    if len(X)==0: return np.array([])
    device = next(model.parameters()).device
    ds = TensorDataset(torch.from_numpy(X))
    dl = DataLoader(ds, batch_size=min(batch_size, len(ds)), shuffle=False, drop_last=False)
    total = None; n_samples = 0
    model.eval()
    with torch.no_grad(), _amp_ctx():
        for (xb,) in dl:
            xb = xb.to(device); _ = model(xb)
            a = model.last_alpha.squeeze(-1)  # (B,T,F)
            # Per-sample mean over time, summed over the batch.
            batch_sum = a.mean(dim=1).sum(dim=0).cpu().numpy().astype(np.float64)
            total = batch_sum if total is None else total + batch_sum
            n_samples += a.shape[0]
    return (total / max(n_samples,1)).astype(np.float32)

def rmse(a,b):
    """Root-mean-square difference of two array-likes, as a Python float."""
    diff = np.asarray(a) - np.asarray(b)
    return float(np.sqrt(np.mean(diff**2)))

# ---- Safe robust scaling ----
def _safe_robust_params(X: np.ndarray, q_low=25.0, q_high=75.0, eps=1e-6):
    Xf = X.reshape(-1, X.shape[2])
    med = np.nanmedian(Xf, axis=0)
    q1  = np.nanpercentile(Xf, q_low, axis=0)
    q3  = np.nanpercentile(Xf, q_high, axis=0)
    iqr = q3 - q1
    scale = np.where(iqr <= eps, 1.0, iqr)
    return med.astype(np.float32), scale.astype(np.float32)

def fit_scaler_from_windows_safe(X_tr, X_va=None, X_te=None):
    """Fit robust scaling parameters on the first non-empty split.

    Splits are tried in the order train, validation, test; the first one with
    at least one window is used. When every split is empty or None, an
    identity scaler (center 0, scale 1) is returned. Its width is the feature
    count of the first 3-D array supplied (even an empty one), or 1 when no
    array carries a feature axis.

    Returns:
        {"center": float32 array (F,), "scale": float32 array (F,)}
    """
    splits = (X_tr, X_va, X_te)
    for X in splits:
        if X is not None and len(X) > 0:
            med, scale = _safe_robust_params(X)
            return {"center": med, "scale": scale}
    # Nothing to fit on: read F from the shape. An empty array has size 0 but
    # still knows its number of features.
    F = next(
        (X.shape[2] for X in splits if X is not None and getattr(X, "ndim", 0) >= 3),
        1,
    )
    return {"center": np.zeros(F, np.float32), "scale": np.ones(F, np.float32)}

def norm_windows(X, rs, clip=20.0):
    """Apply a fitted robust scaler to a 3-D window array.

    Each feature (last axis) is shifted by ``rs["center"]`` and divided by
    ``rs["scale"]``. NaN and +/-inf results are replaced with 0, the output is
    cast to float32, and, unless ``clip`` is ``None``, values are limited to
    ``[-clip, clip]``. An input with no windows is returned unchanged.
    """
    if len(X) == 0:
        return X
    n_features = X.shape[2]
    center = rs["center"].reshape(1, 1, n_features)
    scale = rs["scale" ].reshape(1, 1, n_features)
    scaled = np.nan_to_num((X - center) / scale, nan=0.0, posinf=0.0, neginf=0.0)
    scaled = scaled.astype(np.float32)
    if clip is None:
        return scaled
    # Clip in place to avoid a second full-size allocation.
    np.clip(scaled, -clip, clip, out=scaled)
    return scaled

def _assert_finite(name, *arrays):
    for a in arrays:
        if a is None or len(a) == 0: continue
        if not np.isfinite(a).all(): raise ValueError(f"{name}: found non-finite values.")

# ---- Multipliers (unchanged)
# Regexes (matched case-insensitively) that mark a base name as wind stress.
WIND_STRESS_PATTERNS = [
    r"(^|_)taux($|_)",
    r"(^|_)tauy($|_)",
    r"wind[_-]?stress",
    r"^Twind",
    r"_stress$",
]


def is_wind_stress_base(base:str)->bool:
    """Return True when ``base`` matches any entry of ``WIND_STRESS_PATTERNS``.

    Non-string inputs are converted with ``str()`` before matching.
    """
    text = base if isinstance(base, str) else str(base)
    for pattern in WIND_STRESS_PATTERNS:
        if re.search(pattern, text, flags=re.IGNORECASE):
            return True
    return False


def compute_base_multipliers(feature_bases:List[str])->np.ndarray:
    """Return one float32 multiplier per base: 0.8 for wind-stress bases, else 1.0."""
    bases = list(feature_bases)
    multipliers = np.ones(len(bases), dtype=np.float32)
    for pos, base in enumerate(bases):
        if is_wind_stress_base(base):
            multipliers[pos] = 0.8
    return multipliers

# -------------------
# Utility: split summary printer
# -------------------
def _fmt(ts):
    return ts.strftime("%Y-%m-%d %H:%M") if (ts is not None and pd.notna(ts)) else "—"
def print_split_summary(horizon, t_tr, t_va, t_te):
    def rng(t):
        if len(t)==0: return (None,None,0)
        return (t.min(), t.max(), len(t))
    tr0,tr1,ntr = rng(t_tr); va0,va1,nva = rng(t_va); te0,te1,nte = rng(t_te)
    print(f"[{horizon}m] Split ranges:")
    print(f"  Train: {_fmt(tr0)} → {_fmt(tr1)}  (n={ntr})")
    print(f"  Valid: {_fmt(va0)} → {_fmt(va1)}  (n={nva})")
    print(f"  Test : {_fmt(te0)} → {_fmt(te1)}  (n={nte})")
    return (tr0,tr1),(va0,va1),(te0,te1)

# -------------------
# Windows + splits + training
# -------------------
def make_windows_from_lagged(df:pd.DataFrame, lag_cols:List[str],
                             history_hours:int, horizon_min:int, dt_minutes:int,
                             stride:int=1):
    """Cut fixed-length input windows and look-ahead targets from a lagged table.

    Rows with a missing value in any of ``lag_cols`` or in ``var1h`` are
    dropped first; windows are then taken by row position in what remains.
    A window covers ``steps_hist`` consecutive rows and its target is the
    ``var1h`` value ``steps_ahead`` rows after the window's last row. Windows
    containing a non-finite feature, or whose target is non-finite, are skipped.

    NOTE(review): row positions are treated as consecutive ``dt_minutes``
    steps, so after rows are dropped a window can straddle a time gap —
    confirm that upstream gap masking makes this acceptable.

    Args:
        df: Table with ``time``, ``var1h`` and every column in ``lag_cols``.
        lag_cols: Feature columns, in the order used for the last axis of X.
        history_hours: Window length in hours.
        horizon_min: Forecast lead in minutes.
        dt_minutes: Sampling step in minutes used to convert both to rows.
        stride: Row step between consecutive window ends.

    Returns:
        ``(X, y, t)``: float32 array ``(n, steps_hist, len(lag_cols))``,
        float32 targets ``(n,)`` and a ``DatetimeIndex`` of target times.
        All three are empty when no window can be formed.
    """
    steps_hist  = int(round(history_hours*60/dt_minutes))
    steps_ahead = int(round(horizon_min/dt_minutes))
    F = len(lag_cols)

    def _empty():
        # Same dtypes as the populated return so callers see one consistent type.
        return (np.empty((0, steps_hist, F), dtype=np.float32),
                np.empty((0,), dtype=np.float32),
                pd.DatetimeIndex([]))

    tab = df[["time","var1h"] + lag_cols]
    tab = tab.dropna(subset=lag_cols + ["var1h"]).reset_index(drop=True)
    Ttot = len(tab)
    # One window needs steps_hist rows of history plus steps_ahead rows of lead.
    if Ttot < steps_hist + steps_ahead:
        return _empty()

    # Pull plain arrays once; per-window pandas slicing dominates runtime otherwise.
    feats  = tab[lag_cols].to_numpy(dtype=np.float64)
    target = tab["var1h"].to_numpy(dtype=np.float64)
    times  = pd.to_datetime(tab["time"]).to_numpy()

    ends = np.asarray(range(steps_hist, Ttot - steps_ahead + 1, stride), dtype=np.intp)
    if ends.size == 0:
        return _empty()
    starts = ends - steps_hist
    tgt    = ends - 1 + steps_ahead

    # bad_before[i] = number of rows < i holding a non-finite feature, so a
    # window [s, e) is clean exactly when bad_before[e] == bad_before[s].
    row_bad    = ~np.isfinite(feats).all(axis=1)
    bad_before = np.concatenate(([0], np.cumsum(row_bad)))
    keep = (bad_before[ends] == bad_before[starts]) & np.isfinite(target[tgt])
    starts, tgt = starts[keep], tgt[keep]
    if starts.size == 0:
        return _empty()

    feats32 = feats.astype(np.float32)
    X = np.empty((len(starts), steps_hist, F), dtype=np.float32)
    for i, s in enumerate(starts):
        X[i] = feats32[s:s + steps_hist]
    y = target[tgt].astype(np.float32)
    return X, y, pd.DatetimeIndex(times[tgt])

def adaptive_index_split(X, y, t_idx, train_frac=0.70, val_frac=0.15):
    """Split aligned sequences by position into train / validation / test.

    With fewer than 10 samples everything is returned as the test split and
    train/validation are empty slices. Otherwise the cut points are rounded
    from the fractions and then clamped so that train holds at least one
    sample, validation at least one, and test at least one.

    Returns:
        Three ``(X, y, t_idx)`` triples in train, validation, test order.
    """
    n = len(X)

    def take(part):
        return X[part], y[part], t_idx[part]

    if n < 10:
        nothing = slice(0, 0)
        return take(nothing), take(nothing), (X, y, t_idx)

    train_end = max(1, int(round(train_frac * n)))
    val_end = max(train_end + 1, int(round((train_frac + val_frac) * n)))
    val_end = min(val_end, n - 1)   # always leave at least one test sample
    return (
        take(slice(None, train_end)),
        take(slice(train_end, val_end)),
        take(slice(val_end, None)),
    )

def cap_split(X_, y_, t_, cap=MAX_SAMPLES_SPLIT):
    """Thin a split down to at most ``cap`` samples, evenly spaced by position.

    Splits that already fit are returned as given. Otherwise ``cap`` integer
    positions running from the first to the last sample are selected and
    applied to all three aligned sequences.
    """
    n_total = len(X_)
    if n_total > cap:
        keep = np.linspace(0, n_total - 1, cap, dtype=int)
        X_, y_, t_ = X_[keep], y_[keep], t_[keep]
    return X_, y_, t_

# Helper: plot raw inputs in window ordered by importance
def plot_inputs_strip(df_all, best_list, start, end, title_prefix=""):
    """Plot the raw series of each selected input over ``[start, end]``.

    One stacked subplot is drawn per entry of ``best_list``, in the order
    given, sharing the time axis.

    Args:
        df_all: Table with a ``time`` column and one column per base name.
        best_list: Sequence of ``(base, selected_name, importance)`` tuples.
            The lag shown in each panel title is parsed from a
            ``__lag_<N>min`` suffix in ``selected_name``.
        start, end: Inclusive window bounds; they are also formatted as
            dates in the figure title.
        title_prefix: Text placed at the front of the figure title.

    Returns:
        None. Prints a note and returns early when the window holds no rows
        or when ``best_list`` is empty.
    """
    win = df_all[(df_all["time"] >= start) & (df_all["time"] <= end)].copy()
    if win.empty:
        print("  [inputs strip] no data in window."); return
    n = len(best_list)
    if n == 0:
        # A figure with zero rows cannot be created; nothing to draw anyway.
        print("  [inputs strip] no features to plot."); return
    fig, axes = plt.subplots(n, 1, figsize=(14, max(2.0*n, 4.0)), sharex=True)
    if n == 1: axes = [axes]   # a single Axes is not returned as a sequence
    for ax, (base, sel_name, imp) in zip(axes, best_list):
        if base not in win.columns:
            ax.text(0.5, 0.5, f"{base} missing", transform=ax.transAxes, ha="center", va="center")
            ax.grid(alpha=0.3); continue
        s = win[base].astype(float).copy()
        if base in ("Hs","DPD","APD"):
            # Values above 50 are blanked for these columns
            # (presumably missing-data fill codes — TODO confirm).
            s[s > 50] = np.nan
        ax.plot(win["time"], s, lw=1.0)
        lag_min = re.search(r"__lag_(\d+)min", sel_name)
        lag_txt = f"{lag_min.group(1)} min" if lag_min else "?"
        ax.set_ylabel(base)
        ax.set_title(f"{base}  (selected: {sel_name}, lag={lag_txt}, imp={imp:.3g})", fontsize=9)
        ax.grid(alpha=0.3)
    axes[-1].set_xlabel("Time")
    fig.suptitle(f"{title_prefix} raw inputs [{start:%Y-%m-%d} … {end:%Y-%m-%d}]", y=0.995, fontsize=11)
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()

# -------------------
# Main loop
# -------------------
import torch
from torch.utils.data import TensorDataset, DataLoader

# Per-horizon outputs: validation/test timestamps, truth, predictions and test metrics.
results = {}
for horizon in HORIZONS_MIN:
    print(f"\n=== Horizon {horizon} min ===")
    X, y, t_idx = make_windows_from_lagged(df_lag_all, selected_lag_cols,
                                           HISTORY_HOURS, horizon, DT_MINUTES, stride=1)

    if len(t_idx):
        print(f"[{horizon}m] Windows available:", t_idx.min(), "→", t_idx.max(), f"(n={len(t_idx):,})")
    else:
        print(f"[{horizon}m] No windows after lagging/masking."); continue

    # Time-based split: train up to TRAIN_END, validation up to VAL_END, test after.
    mtr = t_idx <= TRAIN_END
    mva = (t_idx > TRAIN_END) & (t_idx <= VAL_END)
    mte = t_idx > VAL_END

    X_tr, y_tr, t_tr = X[mtr], y[mtr], t_idx[mtr]
    X_va, y_va, t_va = X[mva], y[mva], t_idx[mva]
    X_te, y_te, t_te = X[mte], y[mte], t_idx[mte]

    # If the calendar split leaves any part empty, fall back to a positional split.
    if len(X_tr)==0 or len(X_va)==0 or len(X_te)==0:
        print("  Time split produced an empty set; using adaptive 70/15/15 index split.")
        (X_tr, y_tr, t_tr), (X_va, y_va, t_va), (X_te, y_te, t_te) = adaptive_index_split(X, y, t_idx)

    if len(X_te) == 0:
        print("  Still no test after adaptive split; assigning last ~15% as test.")
        n = len(X); cut = max(1, int(0.85*n))
        X_tr, y_tr, t_tr = X[:cut], y[:cut], t_idx[:cut]
        X_va, y_va, t_va = X[:0], y[:0], t_idx[:0]
        X_te, y_te, t_te = X[cut:], y[cut:], t_idx[cut:]

    (_, _), (va_rng0, va_rng1), (te_rng0, te_rng1) = print_split_summary(horizon, t_tr, t_va, t_te)

    # A DataLoader cannot be built with batch_size=0, and early stopping needs a
    # validation signal, so a horizon without both splits is skipped, not crashed on.
    if len(X_tr) == 0 or len(X_va) == 0:
        print("  Not enough windows to train and validate; skipping this horizon.")
        continue

    X_tr, y_tr, t_tr = cap_split(X_tr, y_tr, t_tr)
    X_va, y_va, t_va = cap_split(X_va, y_va, t_va)
    X_te, y_te, t_te = cap_split(X_te, y_te, t_te)

    # Scale every split with parameters fitted on the first non-empty one (train here).
    rs     = fit_scaler_from_windows_safe(X_tr, X_va, X_te)
    X_tr_n = norm_windows(X_tr, rs)
    X_va_n = norm_windows(X_va, rs)
    X_te_n = norm_windows(X_te, rs)
    _assert_finite("X_tr_n", X_tr_n); _assert_finite("X_va_n", X_va_n); _assert_finite("X_te_n", X_te_n)
    _assert_finite("y_tr", y_tr); _assert_finite("y_va", y_va); _assert_finite("y_te", y_te)

    feature_bases = [base_map[c] for c in selected_lag_cols]
    base_mult_vec = compute_base_multipliers(feature_bases)

    def apply_base_mult(xnp):
        """Scale each feature channel by its per-base multiplier (float32 result)."""
        if base_mult_vec is None or len(base_mult_vec)==0: return xnp
        return (xnp * base_mult_vec.reshape(1,1,-1)).astype(np.float32)

    # Train Mini-TFT
    device = torch.device("cuda" if use_gpu else "cpu")
    model = MiniTFT(
        num_features=X_tr_n.shape[2], d_model=D_MODEL, nhead=NHEAD,
        lstm_layers=LSTM_LAYERS, ff_dim=FF_DIM, dropout=DROPOUT
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=min(LR, 1e-3))

    ds_tr = TensorDataset(torch.from_numpy(apply_base_mult(X_tr_n)), torch.from_numpy(y_tr))
    ds_va = TensorDataset(torch.from_numpy(apply_base_mult(X_va_n)), torch.from_numpy(y_va))
    dl_tr = DataLoader(ds_tr, batch_size=min(BATCH_SIZE, len(ds_tr)), shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=min(BATCH_SIZE, len(ds_va)), shuffle=False, drop_last=False)

    best = np.inf; best_state=None; patience=3; bad=0
    # Targets at or above this training quantile get SPIKE_WEIGHT in the loss.
    q = np.quantile(y_tr, SPIKE_P90) if (WEIGHT_SPIKES and len(y_tr)>0) else None

    for ep in range(EPOCHS):
        model.train(); tr_loss=0.0
        for xb,yb in dl_tr:
            xb,yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            xb = torch.nan_to_num(xb, nan=0.0, posinf=0.0, neginf=0.0)
            yb = torch.nan_to_num(yb, nan=0.0, posinf=0.0, neginf=0.0)
            with _amp_ctx():
                pred = model(xb)
                l = (pred - yb) ** 2
                if q is not None:
                    q_t = torch.tensor(q, device=yb.device, dtype=yb.dtype)
                    w = torch.where(yb >= q_t, torch.tensor(SPIKE_WEIGHT, device=yb.device, dtype=yb.dtype),
                                              torch.tensor(1.0,         device=yb.device, dtype=yb.dtype))
                    l = l * w
                loss = l.mean()
            if not torch.isfinite(loss):
                raise RuntimeError("Non-finite training loss encountered.")
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            tr_loss += loss.item() * len(xb)
        tr_loss /= max(1, len(dl_tr.dataset))

        model.eval(); va_loss=0.0
        with torch.no_grad(), _amp_ctx():
            for xb,yb in dl_va:
                xb,yb = xb.to(device), yb.to(device)
                xb = torch.nan_to_num(xb, nan=0.0, posinf=0.0, neginf=0.0)
                out = model(xb)
                va_loss += torch.nn.functional.mse_loss(out, yb, reduction="sum").item()
        va_loss /= max(1, len(dl_va.dataset))
        print(f"[TFT] epoch {ep+1:02d}  trainMSE={tr_loss:.5f}  valMSE={va_loss:.5f}")

        if va_loss + 1e-6 < best:
            # state_dict() hands back the live parameter tensors; clone them so the
            # snapshot is not overwritten by later optimiser steps.
            best = va_loss; bad = 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print("[TFT] early stop."); break
    if best_state is not None: model.load_state_dict(best_state)

    # Predict on both val & test
    def predict_tft(model, X):
        """Run ``model`` over ``X`` in batches without gradients; return stacked outputs."""
        ds = TensorDataset(torch.from_numpy(X))
        dl = DataLoader(ds, batch_size=min(BATCH_SIZE, len(ds)), shuffle=False, drop_last=False)
        out=[]; model.eval()
        with torch.no_grad(), _amp_ctx():
            for (xb,) in dl:
                xb = xb.to(device)
                xb = torch.nan_to_num(xb, nan=0.0, posinf=0.0, neginf=0.0)
                out.append(model(xb).cpu().numpy())
        return np.concatenate(out, axis=0) if len(out)>0 else np.array([])

    yp_va = predict_tft(model, apply_base_mult(X_va_n)) if len(X_va_n)>0 else np.array([])
    yp_te = predict_tft(model, apply_base_mult(X_te_n)) if len(X_te_n)>0 else np.array([])

    if len(yp_te)>0 and len(y_te)>0:
        R2 = r2_score(y_te, yp_te); RMSE = rmse(y_te, yp_te)
    else:
        R2 = float("nan"); RMSE = float("nan")
    results[horizon] = dict(t_val=t_va, y_val=y_va, yp_val=yp_va,
                            t=t_te, y=y_te, yp=yp_te, r2=R2, rmse=RMSE)
    print(f"[TFT | 1/base] {horizon}m  Test R²={R2:.3f}  RMSE={RMSE:.3f}")

    # Variable-selection weights (test). Best-effort diagnostics: a failure here
    # must not abort the run, but it is reported instead of being hidden.
    try:
        w_mean = mean_vsn_weights(model, apply_base_mult(X_te_n))
        if w_mean.size:
            print("  [TFT] mean variable-selection weights (test):")
            ranked = sorted(enumerate(w_mean), key=lambda item: -item[1])[:12]
            for f_idx, w in ranked:
                base = selected_bases[f_idx] if f_idx < len(selected_bases) else f"feat{f_idx}"
                print(f"    {base:<24s}  {w:.4f}")
    except Exception as exc:
        print(f"  [TFT] variable-selection weights unavailable: {exc!r}")

    # ---- Plot window
    if PRED_SPLIT.lower() == "val":
        t_plot, y_plot, yhat_plot = t_va, y_va, yp_va
        split_name = "Validation"
        avail0, avail1 = va_rng0, va_rng1
    else:
        t_plot, y_plot, yhat_plot = t_te, y_te, yp_te
        split_name = "Test"
        avail0, avail1 = te_rng0, te_rng1

    if len(t_plot)==0:
        print(f"  [plot] No {split_name.lower()} points available to plot.")
        continue

    mwin = (t_plot >= PRED_START) & (t_plot <= PRED_END)
    if not np.any(mwin):
        # Requested window misses this split: report it and show the last
        # PLOT_LAST_DAYS days of the split instead.
        print(f"  [plot] No {split_name.lower()} predictions in requested window "
              f"{PRED_START:%Y-%m-%d} to {PRED_END:%Y-%m-%d}.")
        print(f"         Available {split_name.lower()} range: {_fmt(avail0)} → {_fmt(avail1)}")
        end = t_plot[-1]; start = end - pd.Timedelta(days=PLOT_LAST_DAYS)
        mlast = (t_plot >= start) & (t_plot <= end)
        if np.any(mlast):
            ts, yt, ypp = t_plot[mlast], y_plot[mlast], yhat_plot[mlast]
            plt.figure(figsize=(14,4))
            plt.plot(ts, yt, label="Measured", lw=1.2)
            plt.plot(ts, ypp, "--", label=f"Predicted ({split_name})", lw=1.2)
            plt.title(f"TFT — {horizon} min ({split_name}, last {PLOT_LAST_DAYS} days)")
            plt.ylabel("var1h"); plt.xlabel("Time"); plt.grid(alpha=0.3); plt.legend()
            plt.tight_layout(); plt.show()
            plot_inputs_strip(df_all, best_list, start=start, end=end,
                              title_prefix=f"Horizon {horizon} min — {split_name} —")
        continue

    ts, yt, ypp = t_plot[mwin], y_plot[mwin], yhat_plot[mwin]
    plt.figure(figsize=(14,4))
    plt.plot(ts, yt, label="Measured", lw=1.2)
    plt.plot(ts, ypp, "--", label=f"Predicted ({split_name})", lw=1.2)
    plt.title(f"TFT — {horizon} min ({split_name}) — {PRED_START:%b %Y}")
    plt.ylabel("var1h"); plt.xlabel("Time"); plt.grid(alpha=0.3); plt.legend()
    plt.tight_layout(); plt.show()

    plot_inputs_strip(
        df_all=df_all,
        best_list=best_list,
        start=PRED_START,
        end=PRED_END,
        title_prefix=f"Horizon {horizon} min — {split_name} —"
    )

print("\nFinal selected lag channels (one per base):")
for nm in selected_lag_cols:
    print("  ", nm)


df_seq coverage: 2008-03-05 17:00:00 → 2023-04-17 04:36:00 (n=1,269,412)
46050 raw coverage: 2008-03-05 17:50:00 → 2024-12-31 23:50:00 (n=470,547)
46050 yearly fetch status (tail): [(2015, 'OK 8472'), (2016, 'OK 8688'), (2017, 'OK 28478'), (2018, 'OK 52403'), (2019, 'OK 52055'), (2020, 'OK 52211'), (2021, 'OK 51561'), (2022, 'OK 52531'), (2023, 'OK 52486'), (2024, 'OK 52695')]
9435380 NOAA coverage: 2008-01-01 00:00:00 → 2025-01-01 23:00:00 (n=138,914)
