<a href="https://colab.research.google.com/github/jaysmerrill/Yaquina_Bay_Seiching/blob/main/seiche_prediction_realtime_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

# ============================================================
# Seiche Pipeline — Full Rebuild (with all requested updates)
# ============================================================
import io, gzip, requests, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
from contextlib import nullcontext
from sklearn.metrics import r2_score

# Progress bars: use tqdm's auto-selecting wrapper when it imports cleanly;
# otherwise install a no-op stand-in so the fetch loops below still run (without a bar).
try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, *a, **k):  # fallback if tqdm unavailable
        """Return the iterable `x` unchanged; extra tqdm-style arguments (desc, unit, ...) are ignored."""
        return x

# -------------------
# USER KNOBS
# -------------------
# NOTE: several knobs below (TFT_*, XGB_*, split/plot controls, lag caps, spike weights) are not
# referenced in the portion of the file reviewed here; their comments describe intent only.
YEARS               = list(range(2008, 2026))   # swden/stdmet years
DT_MINUTES          = 6                         # base grid step (minutes)

TFT_TARGET_LOG1P = True   # train TFT on log1p(var1h); invert with expm1 for outputs

# --- Baro & XGB selection controls ---
BARO_FORCE_IN_XGB   = True            # always include baro features (best lag)
BARO_FORCE_LAGS_MIN = [0, 60, 120, 240]  # minutes to try for baro/hpa & baro_var
BARO_SPIKE_Q        = 0.90            # spike threshold quantile in transformed target space
BARO_SPIKE_LOSS_W   = 2.0             # extra loss weight when baro_var is spiking

# ---- Response processing knobs ----
# NOTE: the response column is named "var1h" regardless of VAR_WIN_HOURS (see bandpass_wl_var).
VAR_WIN_HOURS       = 2.0                       # centered rolling-variance window (hours)
BANDPASS_MIN_MIN    = 20.0                      # lower period (minutes)
BANDPASS_MAX_MIN    = 30.0                      # upper period (minutes)
BP_ORDER            = 6                         # Butterworth order

# Gap handling: MAX_GAP_HOURS is the minimum gap length (hours) recorded by gap_intervals();
# the unified gap mask is only built when USE_UNIFIED_GAP_MASK is True (otherwise it is all-False).
MAX_GAP_HOURS       = 5
USE_UNIFIED_GAP_MASK= False

# ---- Windowing / horizons ----
HISTORY_HOURS       = 12
HORIZONS_MIN        = [60, 120, 240]
PRIMARY_HORIZON     = 60

# If you already loaded df_wl6 yourself (with columns ['time','wl']), set False.
WL_USE_FETCH        = True

# NOAA wind (CO-OPS) fetch window (chunked)
NOAA_BEGIN_DATE     = "20080101"
NOAA_END_DATE       = "20251001"
NOAA_WIND_STATION   = "9435380"   # Newport / South Beach

# Water level (response) station and date range — default to 9435380
NOAA_WL_STATION     = "9435380"
WL_BEGIN_DATE       = "20080101"
WL_END_DATE         = "20251001"
# Chunk length (days) passed to the wind fetch only; the water-level fetch passes 31 explicitly.
NOAA_CHUNK_DAYS     = 365

# HMSC monthly archive (backfill for wind & baro)
HMSC_BASE_URL       = "http://weather.hmsc.oregonstate.edu/weather/weatherproject/archive/{yyyy}/HMSC_{yyyymm}.dat"
# Hours added to every parsed HMSC timestamp (fixed offset; no daylight-saving adjustment is applied).
HMSC_UTC_OFFSET_H   = +8  # PST → UTC

# ---- Lag feature caps ----
NON_EQ_MAX_LAG_MIN  = 4 * 60
EQ_MAX_LAG_MIN      = 24 * 60
TOP_LAG_FEATURES    = 12
PER_BASE_CAP        = 1

# ---- Spike weighting on target ----
WEIGHT_SPIKES       = True
SPIKE_P90           = 0.90
SPIKE_WEIGHT        = 2.0

# ---- Time splits ----
TRAIN_END           = pd.Timestamp("2018-12-31 23:59:59")
VAL_END             = pd.Timestamp("2022-12-31 23:59:59")
MAX_SAMPLES_SPLIT   = 60_000

# ---- Plotting controls ----
PRED_SPLIT          = "test"     # "test" or "val"
PRED_START          = pd.Timestamp("2022-01-01 00:00:00")
PRED_END            = pd.Timestamp("2022-01-31 23:59:59")
PLOT_LAST_DAYS      = 40
PLOT_TOP_INPUTS     = 6          # max number of input drivers to show per horizon

# ===================
#  MODEL KNOBS
# ===================

# XGBoost knobs (used for lag ranking / feature selection)
XGB_N_ESTIMATORS    = 800
XGB_LEARNING_RATE   = 0.03
XGB_MAX_DEPTH       = 5
XGB_SUBSAMPLE       = 0.8
XGB_COLSAMPLE       = 0.8
XGB_L2              = 1.0                        # reg_lambda
XGB_VERBOSE         = 0                          # 0=quiet

# Mini-TFT knobs
TFT_D_MODEL         = 128
TFT_NHEAD           = 4
TFT_LSTM_LAYERS     = 1
TFT_FF_DIM          = 256
TFT_DROPOUT         = 0.2
TFT_LR              = 2e-3
TFT_EPOCHS          = 12
TFT_PATIENCE        = 3
TFT_BATCH           = 512

# TFT output constraints (variance ≥ 0)
TFT_NONNEG_OUTPUT   = True
TFT_VAR_FLOOR       = 1e-8

# ---- HMSC / wind weighting / baro variance knobs ----
# WIND_WEIGHT multiplies the interpolated wind speed before it is stored in seq_mat.
WIND_WEIGHT                 = 0.35      # scale applied to wind speed (0=no wind influence, 1=original)
# Lower-cased HMSC column names are searched for these keys (substring match, in this priority order).
BARO_COL_KEYS               = ["bp", "baro", "barometer", "press", "pressure", "slp"]
BARO_VAR_HIGHPASS_HOURS     = 4.0       # high-pass cutoff period (hours) before variance
BARO_VAR_WIN_HOURS          = 4.0       # centered rolling variance window length (hours)
# HMSC pressure readings outside [BARO_CLIP_MIN_HPA, BARO_CLIP_MAX_HPA] are replaced with NaN.
BARO_CLIP_MIN_HPA           = 900.0
BARO_CLIP_MAX_HPA           = 1050.0

# -------------------
# Robust HTTP session
# -------------------
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def _make_session():
    """Create a requests.Session that retries transient HTTP failures.

    Up to 4 retries with a 0.6 backoff factor are attempted for HEAD/GET/OPTIONS requests that
    come back with 429/500/502/503/504; an exhausted retry budget does not raise on status.
    The same policy is mounted for both https:// and http:// URLs, and a browser-like
    User-Agent plus a permissive Accept header are set on every request.
    """
    retry_policy = Retry(
        total=4,
        backoff_factor=0.6,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=["HEAD", "GET", "OPTIONS"],
        raise_on_status=False,
    )
    http = requests.Session()
    http.headers.update({
        "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) SeicheFetcher/1.0",
        "Accept": "*/*",
    })
    for scheme in ("https://", "http://"):
        http.mount(scheme, HTTPAdapter(max_retries=retry_policy))
    return http

# Module-level HTTP session shared by the fetch_* helpers below (they call session.get).
session = _make_session()

# -------------------
# Helpers
# -------------------
def _ndbc_to_float(tok:str) -> float:
    s = str(tok).strip()
    if s == "" or s.upper() == "MM":
        return np.nan
    if re.fullmatch(r"-?9+(\.0+)?", s):  # 99, 9999, 999.00, etc. -> NaN
        return np.nan
    try:
        return float(s)
    except Exception:
        return np.nan

def _datestr(dt) -> str:
    return pd.Timestamp(dt).strftime("%Y%m%d")

# -------------------
# NDBC 46050 STDMET (Hs/DPD/APD/MWD)
# -------------------
# URL templates for one year of station 46050 standard-met data; fetch_stdmet_year tries the
# direct .gz download first and the view_text_file.php endpoint second.
STDMET_DIRECT_PAT = "https://www.ndbc.noaa.gov/data/historical/stdmet/46050h{year}.txt.gz"
STDMET_PHP_PAT    = "https://www.ndbc.noaa.gov/view_text_file.php?filename=46050h{year}.txt.gz&dir=data/historical/stdmet/"

def fetch_stdmet_year(year:int)->str|None:
    """Download one year of 46050 standard-met text, or return None if every source fails.

    Each candidate URL is tried in order. A 200 response with a non-empty body is first
    treated as gzip; if gzip rejects it with OSError the raw bytes are decoded as text
    instead. Any other failure moves on to the next URL.
    """
    candidates = [pat.format(year=year) for pat in (STDMET_DIRECT_PAT, STDMET_PHP_PAT)]
    for url in candidates:
        try:
            resp = session.get(url, timeout=30)
            payload = resp.content if resp.status_code == 200 else None
            if not payload:
                continue
            try:
                with gzip.GzipFile(fileobj=io.BytesIO(payload)) as gz:
                    raw = gz.read()
            except OSError:
                # Not a gzip stream: the endpoint served plain text.
                raw = payload
            return raw.decode("utf-8", errors="ignore")
        except Exception:
            continue
    return None

def parse_stdmet_text(txt:str)->pd.DataFrame:
    """Parse NDBC standard-met text into a frame of wave summary columns.

    Each data row is whitespace-delimited: five time tokens (year, month, day, hour, minute)
    followed by measurement tokens that are assigned, in order, to
    WDIR, WSPD, GST, WVHT, DPD, APD, MWD, PRES, ATMP, WTMP, DEWP, VIS, TIDE.
    Blank lines, lines starting with "#", lines with fewer than six tokens and lines whose
    timestamp cannot be built are skipped. Two-digit years are mapped to 2000+YY.

    Rows are padded with "MM" or truncated to exactly thirteen measurement tokens, so a file
    with extra trailing columns no longer makes the DataFrame constructor fail.

    Returns:
        DataFrame with columns ["time", "Hs", "DPD", "APD", "MWD"] (time is naive, sorted,
        unique; Hs is the renamed WVHT). Hs/DPD/APD above 50 and MWD outside [0, 360] are
        set to NaN. An empty DataFrame with no columns is returned when no row parses.
    """
    value_cols = ["WDIR","WSPD","GST","WVHT","DPD","APD","MWD","PRES","ATMP","WTMP","DEWP","VIS","TIDE"]
    n_vals = len(value_cols)

    rows = []
    for ln in txt.splitlines():
        if not ln or ln.lstrip().startswith("#"):
            continue
        parts = re.split(r"\s+", ln.strip())
        if len(parts) < 6:
            continue
        try:
            y, mo, dy, hh, mi = (int(p) for p in parts[:5])
            year = 2000 + y if y < 100 else y
            ts = pd.Timestamp(year=year, month=mo, day=dy, hour=hh, minute=mi, tz="UTC").tz_localize(None)
        except (ValueError, TypeError, OverflowError):
            # Non-numeric or out-of-range time fields: not a data row.
            continue
        # Always normalise to exactly n_vals tokens (pad short rows, drop surplus columns).
        vals = (parts[5:] + ["MM"] * n_vals)[:n_vals]
        rows.append([ts] + [_ndbc_to_float(v) for v in vals])

    if not rows:
        return pd.DataFrame()

    df = pd.DataFrame(rows, columns=["time"] + value_cols)
    df = df.sort_values("time").drop_duplicates("time")
    df = df.rename(columns={"WVHT": "Hs"}).reset_index(drop=True)

    # Implausibly large heights/periods are treated as missing.
    for col, upper in (("Hs", 50.0), ("DPD", 50.0), ("APD", 50.0)):
        df.loc[df[col] > upper, col] = np.nan
    df.loc[(df["MWD"] < 0) | (df["MWD"] > 360), "MWD"] = np.nan

    return df[["time", "Hs", "DPD", "APD", "MWD"]]

def build_stdmet(years: List[int]) -> pd.DataFrame:
    """Fetch and parse standard-met data for each year, then merge into one frame.

    Years whose download fails or parses to nothing are skipped. The result is sorted by
    time with duplicate timestamps removed; if no year yields data, an empty frame with the
    columns time/Hs/DPD/APD/MWD is returned.
    """
    yearly = []
    for yr in tqdm(years, desc="STDMET years", unit="yr"):
        raw = fetch_stdmet_year(yr)
        parsed = parse_stdmet_text(raw) if raw else None
        if parsed is not None and not parsed.empty:
            yearly.append(parsed)

    if not yearly:
        return pd.DataFrame(columns=["time","Hs","DPD","APD","MWD"])

    merged = pd.concat(yearly, ignore_index=True)
    merged = merged.sort_values("time").drop_duplicates("time")
    return merged.reset_index(drop=True)

# Runs at import/cell execution time: downloads and parses every year in YEARS.
df_stdmet = build_stdmet(YEARS)

# -------------------
# NDBC 46050 SWDEN → 3-band Hm0
# -------------------
# URL templates tried in order by fetch_swden_year for one year of spectral wave density text.
SWDEN_PHP_PATS = [
    "https://www.ndbc.noaa.gov/view_text_file.php?filename=46050w{year}.txt.gz&dir=data/historical/swden/",
    "https://www.ndbc.noaa.gov/view_text_file.php?filename=46050w{year}.txt.gz&dir=data/swden/46050/",
]

def fetch_swden_year(year:int)->str|None:
    """Return the spectral-density text for `year`, or None if no URL template succeeds.

    A response counts as a success when it has status 200, a non-empty body, and the body
    does not contain the phrase "Page Not Found".
    """
    for pattern in SWDEN_PHP_PATS:
        try:
            resp = session.get(pattern.format(year=year), timeout=40)
            if resp.status_code != 200:
                continue
            body = resp.text
            if body and "Page Not Found" not in body:
                return body
        except Exception:
            continue
    return None

def split_swden_blocks(lines:List[str]):
    """Split raw file lines into (header, freq_line, data).

    header    -- the leading run of lines that start with "#"
    freq_line -- the first non-blank line after the header (trailing newline removed),
                 or None when there is no such line
    data      -- every non-blank line after freq_line
    """
    total = len(lines)

    pos = 0
    while pos < total and lines[pos].startswith("#"):
        pos += 1
    header = list(lines[:pos])

    # Skip blank separator lines between the header and the first content line.
    while pos < total and not lines[pos].strip():
        pos += 1

    if pos < total:
        freq_line = lines[pos].rstrip("\n")
        pos += 1
    else:
        freq_line = None

    data = [row for row in lines[pos:] if row.strip()]
    return header, freq_line, data

def parse_swden_year(txt:str)->pd.DataFrame:
    """Parse one year of spectral wave density text into three band-integrated wave heights.

    Supported layouts (detected from the text itself):
      A. A label row, optionally prefixed with "#", whose leading tokens are time labels
         (YY/YYYY MM DD hh [mm]) and whose remaining tokens are the frequencies.
         NOTE(review): this is believed to be the layout of NDBC historical swden files —
         confirm against a downloaded file.
      B. A label row with no numbers, followed by a separate frequency row.
      C. No label row: the first non-comment, non-blank line is the frequency row and data
         rows carry four time columns (the previous default behaviour).

    The number of leading time columns in data rows is taken from the label row (4 or 5);
    when a minute column is present it is used in the timestamp instead of being read as a
    spectral value.

    For each row the spectral density is integrated over frequency (trapezoid-like bin widths)
    in three bands and converted with 4*sqrt(m0):
        high_HIG: f < 0.05,  Hswell: 0.05 <= f <= 0.14,  Hsea: 0.14 < f <= 0.30
    A band with no frequencies, or whose values are all missing, yields NaN (not 0).

    Returns:
        DataFrame with columns ["time", "high_HIG", "Hswell", "Hsea"], sorted by time with
        duplicate timestamps removed; an empty DataFrame with no columns if nothing parses.
    """
    time_labels = ("YY", "YYYY", "MM", "DD", "hh", "mm")

    def _num(tok):
        try:
            return float(tok)
        except (TypeError, ValueError):
            return np.nan

    lines = txt.splitlines()

    # --- Look for a label row among the leading comment lines / first content line ---
    label_idx = None
    n_time = 4
    header_rest = []
    for k, ln in enumerate(lines):
        stripped = ln.strip()
        if not stripped:
            continue
        toks = stripped.lstrip("#").split()
        n_lab = 0
        while n_lab < len(toks) and toks[n_lab] in time_labels:
            n_lab += 1
        if n_lab >= 4:
            label_idx, n_time, header_rest = k, n_lab, toks[n_lab:]
            break
        if not stripped.startswith("#"):
            break  # first real content line is not a label row -> layout C

    # --- Resolve the frequency tokens and the data body for the detected layout ---
    if label_idx is not None and any(np.isfinite(_num(t)) for t in header_rest):
        freq_tokens = header_rest                      # layout A
        body = lines[label_idx + 1:]
    elif label_idx is not None:
        j = label_idx + 1                              # layout B
        while j < len(lines) and not lines[j].strip():
            j += 1
        if j >= len(lines):
            return pd.DataFrame()
        freq_tokens = lines[j].split()
        body = lines[j + 1:]
    else:
        _, freq_line, body = split_swden_blocks(lines)  # layout C
        if not freq_line:
            return pd.DataFrame()
        freq_tokens = freq_line.split()
    data_lines = [ln for ln in body if ln.strip()]

    freqs = np.array([_num(t) for t in freq_tokens], dtype=float)
    freqs = freqs[np.isfinite(freqs)]
    if freqs.size == 0:
        return pd.DataFrame()

    # Bin widths: central differences inside, one-sided at the ends; negative widths -> NaN.
    dfw = np.empty_like(freqs)
    if freqs.size == 1:
        dfw[:] = 0.0
    else:
        dfw[1:-1] = 0.5 * (freqs[2:] - freqs[:-2])
        dfw[0]    = freqs[1] - freqs[0]
        dfw[-1]   = freqs[-1] - freqs[-2]
        dfw = np.where(dfw < 0, np.nan, dfw)

    band_masks = (
        freqs < 0.05,                          # high_HIG
        (freqs >= 0.05) & (freqs <= 0.14),     # Hswell
        (freqs > 0.14) & (freqs <= 0.30),      # Hsea
    )

    def _band_height(vals, mask):
        """Return 4*sqrt(m0) for one band, or NaN when the band has no usable data."""
        if not mask.any():
            return np.nan
        contrib = vals[mask] * dfw[mask]
        if not np.isfinite(contrib).any():
            return np.nan  # every bin missing: do not let nansum report a zero height
        m0 = np.nansum(contrib)
        return 4.0 * np.sqrt(m0) if np.isfinite(m0) and m0 >= 0 else np.nan

    rows = []
    for ln in data_lines:
        parts = ln.split()
        if len(parts) < n_time:
            continue
        try:
            y, mo, dy, hh = (int(p) for p in parts[:4])
            mi = int(parts[4]) if n_time >= 5 else 0
            year = 2000 + y if y < 100 else y
            ts = pd.Timestamp(year=year, month=mo, day=dy, hour=hh, minute=mi, tz="UTC").tz_localize(None)
        except (ValueError, TypeError, OverflowError):
            continue

        vals = np.array([_ndbc_to_float(v) for v in parts[n_time:]], dtype=float)
        # Align the row to the frequency axis: pad short rows with NaN, drop surplus values.
        if vals.size < freqs.size:
            vals = np.pad(vals, (0, freqs.size - vals.size), constant_values=np.nan)
        elif vals.size > freqs.size:
            vals = vals[:freqs.size]

        rows.append((ts, *(_band_height(vals, m) for m in band_masks)))

    if not rows:
        return pd.DataFrame()
    df_out = (
        pd.DataFrame(rows, columns=["time","high_HIG","Hswell","Hsea"])
        .sort_values("time")
        .drop_duplicates("time")
        .reset_index(drop=True)
    )
    return df_out

def build_swden(years: List[int]) -> pd.DataFrame:
    """Fetch and parse spectral wave density for each year, then merge into one frame.

    Years whose download fails or parses to nothing are skipped. The result is sorted by
    time with duplicate timestamps removed; if no year yields data, an empty frame with the
    columns time/high_HIG/Hswell/Hsea is returned.
    """
    yearly = []
    for yr in tqdm(years, desc="SWDEN years", unit="yr"):
        raw = fetch_swden_year(yr)
        parsed = parse_swden_year(raw) if raw else None
        if parsed is not None and not parsed.empty:
            yearly.append(parsed)

    if not yearly:
        return pd.DataFrame(columns=["time","high_HIG","Hswell","Hsea"])

    merged = pd.concat(yearly, ignore_index=True)
    merged = merged.sort_values("time").drop_duplicates("time")
    return merged.reset_index(drop=True)

# Runs at import/cell execution time: downloads and parses spectral data for every year in YEARS.
df_swden = build_swden(YEARS)

# -------------------
# NOAA CO-OPS 9435380 wind (hourly) + HMSC fallback (+ baro)
# -------------------
def parse_noaa_wind_csv(txt:str)->pd.DataFrame:
    """Parse a CO-OPS wind CSV response into time / speed / direction columns.

    The time column is the first column whose stripped, lower-cased name is one of
    t / time / date time / date_time / date (falling back to the first column); it is parsed
    as UTC and returned timezone-naive.

    Speed and direction columns are located by name, never considering the time column:
    an exact match on the stripped, lower-cased name is preferred ("s"/"speed",
    "d"/"dir"/"direction"), then a substring match on "speed" / "dir". This avoids the
    one-letter key "d" matching "Date Time", which previously turned every direction into NaN.

    Returns:
        DataFrame with columns ["time", "wind9435380_speed", "wind9435380_dir"], sorted by
        time, rows without a parseable time dropped, duplicate times removed. Directions
        outside [0, 360] and speeds outside [0, 100] become NaN. Blank input yields an empty
        frame with those columns.
    """
    out_cols = ["time", "wind9435380_speed", "wind9435380_dir"]
    if not txt or not txt.strip():
        return pd.DataFrame(columns=out_cols)

    df = pd.read_csv(io.StringIO(txt))

    time_names = ("t", "time", "date time", "date_time", "date")
    time_cols = [c for c in df.columns if str(c).strip().lower() in time_names]
    time_col = time_cols[0] if time_cols else df.columns[0]
    stamps = pd.to_datetime(df[time_col], errors="coerce", utc=True).dt.tz_localize(None)

    # (original name, normalised name) for every column except the time column.
    candidates = [(c, str(c).strip().lower()) for c in df.columns if c != time_col]

    def _find(exact, contains):
        """Return the first column matching exactly, else the first containing a key, else None."""
        for col, low in candidates:
            if low in exact:
                return col
        for col, low in candidates:
            if any(key in low for key in contains):
                return col
        return None

    sp_col = _find(exact=("s", "speed"), contains=("speed",))
    dir_col = _find(exact=("d", "dir", "direction"), contains=("dir",))

    out = pd.DataFrame({"time": stamps})
    out["wind9435380_speed"] = pd.to_numeric(df[sp_col], errors="coerce") if sp_col is not None else np.nan
    out["wind9435380_dir"]   = pd.to_numeric(df[dir_col], errors="coerce") if dir_col is not None else np.nan

    # Physically implausible readings are treated as missing.
    out.loc[(out["wind9435380_dir"] < 0) | (out["wind9435380_dir"] > 360), "wind9435380_dir"] = np.nan
    out.loc[(out["wind9435380_speed"] < 0) | (out["wind9435380_speed"] > 100), "wind9435380_speed"] = np.nan

    out = out.sort_values("time").dropna(subset=["time"]).drop_duplicates("time")
    return out.reset_index(drop=True)

def fetch_noaa_wind_dataframe(begin_date:str, end_date:str, station:str, days_per_chunk:int=365) -> pd.DataFrame:
    """Download hourly wind for `station` between two YYYYMMDD dates, in date-range chunks.

    The range is split into consecutive windows of at most `days_per_chunk` days; each window
    is requested from the CO-OPS datagetter endpoint (metric units, GMT, CSV) and parsed with
    parse_noaa_wind_csv. Windows that fail to download or parse are silently skipped.

    Returns:
        All parsed rows merged, sorted by time with duplicate timestamps removed; an empty
        frame with the time/speed/direction columns if nothing was retrieved.
    """
    first_day = pd.to_datetime(begin_date, format="%Y%m%d")
    last_day  = pd.to_datetime(end_date,   format="%Y%m%d")

    # Build inclusive [lo, hi] windows that tile the requested range.
    windows = []
    lo = first_day
    while lo <= last_day:
        hi = min(last_day, lo + pd.Timedelta(days=days_per_chunk-1))
        windows.append((lo, hi))
        lo = hi + pd.Timedelta(days=1)

    pieces = []
    for lo, hi in tqdm(windows, desc=f"NOAA wind {station}", unit="chunk"):
        url = ("https://api.tidesandcurrents.noaa.gov/api/prod/datagetter"
               f"?begin_date={_datestr(lo)}&end_date={_datestr(hi)}"
               f"&station={station}&product=wind&time_zone=gmt&interval=h&units=metric"
               "&application=SeichePipeline&format=csv")
        try:
            resp = session.get(url, timeout=60)
            if resp.status_code != 200 or not resp.content:
                continue
            part = parse_noaa_wind_csv(resp.content.decode("utf-8","ignore"))
            if not part.empty:
                pieces.append(part)
        except Exception:
            # Best effort: a failed window just leaves a hole in the record.
            continue

    if not pieces:
        return pd.DataFrame(columns=["time","wind9435380_speed","wind9435380_dir"])
    merged = pd.concat(pieces, ignore_index=True)
    merged = merged.sort_values("time").drop_duplicates("time")
    return merged.reset_index(drop=True)

def fetch_hmsc_month_text(year:int, month:int) -> str|None:
    """Download one HMSC monthly archive file as text; return None on any failure.

    The URL is built from HMSC_BASE_URL with the year and a zero-padded YYYYMM stamp.
    Only a 200 response with a non-empty body counts as success.
    """
    url = HMSC_BASE_URL.format(yyyy=year, yyyymm=f"{year}{month:02d}")
    try:
        resp = session.get(url, timeout=60)
        succeeded = resp.status_code == 200 and bool(resp.content)
        return resp.content.decode("utf-8","ignore") if succeeded else None
    except Exception:
        return None

def parse_hmsc_dat(txt:str) -> pd.DataFrame:
    """Parse one HMSC monthly .dat file into time, wind and barometric-pressure columns.

    Everything before the first line that starts with TIMESTAMP (after stripping whitespace
    and double quotes) is discarded; the rest is read as CSV. Rows whose first column is not
    parseable as a datetime are dropped. Timestamps are shifted by HMSC_UTC_OFFSET_H hours.

    Output columns (each present only when its source column is found):
        hmsc_speed_ms / hmsc_dir_deg -- from columns named exactly "aws" and "awd"
            (case-insensitive); both must exist. Speed is multiplied by 0.44704
            (assumes the source is in mph — TODO confirm); directions outside [0, 360]
            become NaN.
        hmsc_baro_hpa -- from the first column whose lower-cased name contains a key from
            BARO_COL_KEYS (keys tried in list order); values outside
            [BARO_CLIP_MIN_HPA, BARO_CLIP_MAX_HPA] become NaN.

    Returns:
        DataFrame sorted by time with NaT rows and duplicate times removed, or an empty
        DataFrame with no columns if no TIMESTAMP header is found.
    """
    lines = txt.splitlines()
    start_idx = next(
        (i for i, ln in enumerate(lines) if ln.strip().strip('"').startswith("TIMESTAMP")),
        None,
    )
    if start_idx is None:
        return pd.DataFrame()
    df = pd.read_csv(io.StringIO("\n".join(lines[start_idx:])))

    def _try_time(x):
        # Per-value parse so non-data rows in the first column simply become NaT.
        try:
            return pd.to_datetime(x)
        except Exception:
            return pd.NaT

    tt = df.iloc[:, 0].apply(_try_time)
    df = df.loc[tt.notna()].copy()
    df.rename(columns={df.columns[0]: "TIMESTAMP"}, inplace=True)

    cols_lower = {str(c).strip().lower(): c for c in df.columns}
    aws_col = cols_lower.get("aws")
    awd_col = cols_lower.get("awd")

    baro_col = None
    for key in BARO_COL_KEYS:
        baro_col = next((orig for low, orig in cols_lower.items() if key in low), None)
        if baro_col is not None:
            break

    out = pd.DataFrame()
    out["time"] = pd.to_datetime(df["TIMESTAMP"], errors="coerce") + pd.Timedelta(hours=HMSC_UTC_OFFSET_H)

    if aws_col is not None and awd_col is not None:
        spd_ms = pd.to_numeric(df[aws_col], errors="coerce") * 0.44704
        dir_deg = pd.to_numeric(df[awd_col], errors="coerce")
        # .where() builds a new float Series, so integer-typed input never needs an in-place NaN write.
        dir_deg = dir_deg.where((dir_deg >= 0) & (dir_deg <= 360))
        out["hmsc_speed_ms"] = spd_ms.astype("float32")
        out["hmsc_dir_deg"] = dir_deg.astype("float32")

    if baro_col is not None:
        bp = pd.to_numeric(df[baro_col], errors="coerce").astype("float32")
        out["hmsc_baro_hpa"] = bp.where((bp >= BARO_CLIP_MIN_HPA) & (bp <= BARO_CLIP_MAX_HPA))

    out = out.sort_values("time").dropna(subset=["time"]).drop_duplicates("time")
    return out.reset_index(drop=True)

def month_iter(start:pd.Timestamp, end:pd.Timestamp):
    """Yield (year, month) pairs from start's month onward while the month's first day <= end."""
    year, month = start.year, start.month
    while True:
        first_of_month = pd.Timestamp(year=year, month=month, day=1)
        if not (first_of_month <= end):
            return
        yield (year, month)
        month += 1
        if month > 12:
            year, month = year + 1, 1

def fetch_hmsc_range(start:pd.Timestamp, end:pd.Timestamp)->pd.DataFrame:
    """Download and parse every HMSC monthly file that overlaps [start, end].

    Each month's rows are trimmed to one day either side of the requested range before
    being collected. Months that fail to download or parse to nothing are skipped.

    Returns:
        All rows merged, sorted by time with duplicate timestamps removed; an empty frame
        with the time/speed/direction/baro columns if nothing was retrieved.
    """
    one_day = pd.Timedelta(days=1)
    earliest, latest = start - one_day, end + one_day

    month_list = list(month_iter(start.normalize(), end.normalize()))
    collected = []
    for yr, mon in tqdm(month_list, desc="HMSC months", unit="mo"):
        raw = fetch_hmsc_month_text(yr, mon)
        if not raw:
            continue
        monthly = parse_hmsc_dat(raw)
        if monthly.empty:
            continue
        keep = (monthly["time"] >= earliest) & (monthly["time"] <= latest)
        collected.append(monthly[keep])

    if not collected:
        return pd.DataFrame(columns=["time","hmsc_speed_ms","hmsc_dir_deg","hmsc_baro_hpa"])
    merged = pd.concat(collected, ignore_index=True)
    merged = merged.sort_values("time").drop_duplicates("time")
    return merged.reset_index(drop=True)

def resample_hmsc_hourly(df_hmsc:pd.DataFrame)->pd.DataFrame:
    """Aggregate HMSC samples to hourly values, one row per hour that contains data.

    hmsc_speed_ms and hmsc_baro_hpa are averaged arithmetically. hmsc_dir_deg is averaged
    on the circle (mean of sin/cos, then atan2, wrapped to 0–360) so that samples straddling
    north, e.g. 350° and 10°, average to ~0° rather than 180°.

    Returns:
        DataFrame with a "time" column (hour start) plus whichever of the three value
        columns were present in the input. An empty input is returned unchanged.
    """
    if df_hmsc.empty:
        return df_hmsc
    g = df_hmsc.set_index("time").sort_index()
    # Only hours that actually contain samples are kept; resample() output is aligned onto them.
    hourly = pd.DataFrame(index=g.index.floor("h").unique())
    for col in ["hmsc_speed_ms", "hmsc_dir_deg", "hmsc_baro_hpa"]:
        if col not in g.columns:
            continue
        if col == "hmsc_dir_deg":
            theta = np.deg2rad(g[col].astype(float))
            sin_mean = np.sin(theta).resample("h").mean()
            cos_mean = np.cos(theta).resample("h").mean()
            hourly[col] = np.degrees(np.arctan2(sin_mean, cos_mean)) % 360.0
        else:
            hourly[col] = g[col].resample("h").mean()
    out = hourly.reset_index().rename(columns={"index": "time"}).dropna(subset=["time"])
    return out.reset_index(drop=True)

def baro_highpass_variance(df_hmsc_hour:pd.DataFrame,
                           hp_hours:float=BARO_VAR_HIGHPASS_HOURS,
                           var_hours:float=BARO_VAR_WIN_HOURS)->pd.DataFrame:
    """High-pass filter hourly barometric pressure and return its centered rolling variance.

    Steps:
      1. Drop rows without a time or pressure and place the remaining values on a regular
         hourly grid spanning the record (duplicate hours are averaged). Missing hours are
         linearly interpolated for filtering only, so the 1-sample-per-hour rate assumed by
         the filter holds even across data gaps.
      2. Apply a 4th-order zero-phase Butterworth high-pass with cutoff period
         max(hp_hours, 0.5) hours.
      3. Blank out the interpolated hours again, then take a centered rolling variance over
         round(var_hours) samples (at least 3), requiring 80% of the window (at least 2).

    Args:
        df_hmsc_hour: frame with "time" and "hmsc_baro_hpa" columns.
        hp_hours: high-pass cutoff period in hours.
        var_hours: rolling-variance window length in hours.

    Returns:
        DataFrame with columns ["time", "baro_var4h"] (float32), one row per hour that had
        an observed pressure; empty (with those columns) when there is no usable input.
    """
    from scipy.signal import butter, filtfilt

    if df_hmsc_hour.empty or "hmsc_baro_hpa" not in df_hmsc_hour.columns:
        return pd.DataFrame(columns=["time","baro_var4h"])
    df = df_hmsc_hour.dropna(subset=["time","hmsc_baro_hpa"]).sort_values("time")
    if df.empty:
        return pd.DataFrame(columns=["time","baro_var4h"])

    s = df.set_index("time")["hmsc_baro_hpa"].astype(float)
    s = s.groupby(s.index.floor("h")).mean()

    # Regular hourly axis: gaps stay as real elapsed time instead of being squeezed out.
    grid = pd.date_range(s.index[0], s.index[-1], freq="h")
    on_grid = s.reindex(grid)
    observed = on_grid.notna().to_numpy()
    x = on_grid.interpolate(method="time").to_numpy(dtype=float)

    fs = 1.0  # hourly
    fc = 1.0 / max(hp_hours, 0.5)
    wn = min(max(fc / (fs/2.0), 1e-4), 0.999)
    b, a = butter(4, wn, btype="highpass")
    xf = filtfilt(b, a, x, method="gust")

    # Interpolated samples were only scaffolding for the filter; exclude them from the variance.
    xf = np.where(observed, xf, np.nan)

    win = max(3, int(round(var_hours)))
    minp = max(2, int(0.8 * win))
    var_series = pd.Series(xf, index=grid).rolling(window=win, center=True, min_periods=minp).var()
    var_series = var_series[observed]

    out = pd.DataFrame({
        "time": var_series.index,
        "baro_var4h": var_series.to_numpy(dtype="float32"),
    }).dropna(subset=["time"])
    return out.reset_index(drop=True)

# Build wind + baro with fill
# 1) NOAA wind is the primary source; its time span defines the HMSC fill window.
df_noaa_wind = fetch_noaa_wind_dataframe(NOAA_BEGIN_DATE, NOAA_END_DATE, NOAA_WIND_STATION, days_per_chunk=NOAA_CHUNK_DAYS)
if len(df_noaa_wind):
    fill_start, fill_end = df_noaa_wind["time"].min(), df_noaa_wind["time"].max()
else:
    # No NOAA wind at all: fall back to the widest span covered by the wave inputs,
    # and to a hard-coded 2008-01-01..2025-10-01 range if those are empty too.
    candidates=[]
    if len(df_swden): candidates.append((df_swden["time"].min(), df_swden["time"].max()))
    if len(df_stdmet): candidates.append((df_stdmet["time"].min(), df_stdmet["time"].max()))
    if candidates:
        fill_start = min(a for a,_ in candidates); fill_end = max(b for _,b in candidates)
    else:
        fill_start = pd.Timestamp("2008-01-01"); fill_end=pd.Timestamp("2025-10-01")

# 2) HMSC archive over the same window: raw samples -> hourly values -> baro variance.
df_hmsc_raw  = fetch_hmsc_range(fill_start, fill_end)
df_hmsc_hour = resample_hmsc_hourly(df_hmsc_raw)
df_baro_var  = baro_highpass_variance(df_hmsc_hour, hp_hours=BARO_VAR_HIGHPASS_HOURS, var_hours=BARO_VAR_WIN_HOURS)

# ---- NOAA/HMSC hourly merge with per-column fill (NOAA priority), speed/dir only ----
hourly_grid = pd.date_range(fill_start.floor("h"), fill_end.ceil("h"), freq="h")

# Both sources are reindexed onto hourly_grid so their rows line up position-by-position.
noaa = (df_noaa_wind.set_index("time").reindex(hourly_grid).rename_axis("time").reset_index()) \
        if len(df_noaa_wind) else pd.DataFrame({"time":hourly_grid})
noaa.rename(columns={
    "wind9435380_speed": "speed",
    "wind9435380_dir":   "dir",
}, inplace=True)

if len(df_hmsc_hour):
    hmsc = (df_hmsc_hour.set_index("time").reindex(hourly_grid).rename_axis("time").reset_index())
else:
    hmsc = pd.DataFrame({
        "time": hourly_grid,
        "hmsc_speed_ms": np.nan,
        "hmsc_dir_deg":  np.nan,
        "hmsc_baro_hpa": np.nan,
    })

# Fill only where NOAA is missing and HMSC has a value; existing NOAA readings are never overwritten.
for out_col, src_col in [("speed", "hmsc_speed_ms"), ("dir", "hmsc_dir_deg")]:
    if out_col not in noaa.columns:
        noaa[out_col] = np.nan
    if src_col in hmsc.columns:
        m = noaa[out_col].isna() & hmsc[src_col].notna()
        noaa.loc[m, out_col] = hmsc.loc[m, src_col].values

# Final hourly wind table, restored to the original NOAA column names and stored as float32.
df_wind_filled = pd.DataFrame({
    "time": noaa["time"],
    "wind9435380_speed": pd.to_numeric(noaa["speed"], errors="coerce").astype("float32"),
    "wind9435380_dir":   pd.to_numeric(noaa["dir"],   errors="coerce").astype("float32"),
}).dropna(subset=["time"]).reset_index(drop=True)

# -------------------
# NOAA 9435380 water level (6-min) → bandpass(20–30 min) → centered variance(2h)
# -------------------
from scipy.signal import butter, filtfilt

def parse_noaa_wl_csv(txt:str)->pd.DataFrame:
    """Parse a CO-OPS water-level CSV response into a two-column (time, wl) frame.

    Column names are stripped, lower-cased and have double spaces collapsed once before
    lookup. The time column is the first of date time / date_time / time / t / date that
    exists; the value column is the first of water level / water_level / wl / v, falling
    back to the second column. Times are parsed as UTC and returned timezone-naive.

    Returns:
        Rows with a valid time, sorted and de-duplicated on time; levels outside [-50, 50]
        are set to NaN. An empty frame with columns ["time", "wl"] is returned when the text
        starts with "error", has fewer than two columns, has no recognisable time column,
        or contains no non-missing level.
    """
    def _nothing():
        return pd.DataFrame(columns=["time","wl"])

    if txt.strip().lower().startswith("error"):
        return _nothing()

    df = pd.read_csv(io.StringIO(txt))
    if df.shape[1] < 2:
        return _nothing()
    df.columns = [str(c).strip().lower().replace("  ", " ") for c in df.columns]

    tcol = next((k for k in ("date time","date_time","time","t","date") if k in df.columns), None)
    if tcol is None:
        return _nothing()
    df["time"] = pd.to_datetime(df[tcol], errors="coerce", utc=True).dt.tz_localize(None)

    vcol = next((k for k in ("water level","water_level","wl","v") if k in df.columns), None)
    if vcol is None:
        vcol = df.columns[1]
    df["wl"] = pd.to_numeric(df[vcol], errors="coerce")

    out = df[["time","wl"]].dropna(subset=["time"])
    out = out.sort_values("time").drop_duplicates("time")
    implausible = (out["wl"] < -50) | (out["wl"] > 50)
    out.loc[implausible, "wl"] = np.nan
    if not out["wl"].notna().any():
        return _nothing()
    return out.reset_index(drop=True)

def fetch_noaa_wl(begin_date:str, end_date:str, station:str, days_per_chunk:int=31)->pd.DataFrame:
    """Download 6-minute water level for `station` between two YYYYMMDD dates, in chunks.

    The range is split into consecutive windows of at most `days_per_chunk` days; each window
    is requested from the CO-OPS datagetter endpoint (MLLW datum, metric, GMT, CSV) and
    parsed with parse_noaa_wl_csv. Windows that fail to download or parse are skipped.
    A one-line summary is printed either way.

    Returns:
        All parsed rows merged, sorted by time with duplicate timestamps removed; an empty
        frame with columns ["time", "wl"] if nothing was retrieved.
    """
    first_day = pd.to_datetime(begin_date, format="%Y%m%d")
    last_day  = pd.to_datetime(end_date,   format="%Y%m%d")

    # Build inclusive [lo, hi] windows that tile the requested range.
    windows = []
    lo = first_day
    while lo <= last_day:
        hi = min(last_day, lo + pd.Timedelta(days=days_per_chunk-1))
        windows.append((lo, hi))
        lo = hi + pd.Timedelta(days=1)

    pieces = []
    for lo, hi in tqdm(windows, desc=f"NOAA WL {station}", unit="chunk"):
        url = ("https://api.tidesandcurrents.noaa.gov/api/prod/datagetter"
               f"?begin_date={_datestr(lo)}&end_date={_datestr(hi)}"
               f"&station={station}&product=water_level&datum=MLLW&time_zone=gmt&units=metric"
               "&interval=6&application=SeichePipeline&format=csv")
        try:
            resp = session.get(url, timeout=90)
            if resp.status_code != 200 or not resp.content:
                continue
            part = parse_noaa_wl_csv(resp.content.decode("utf-8","ignore"))
            if not part.empty:
                pieces.append(part)
        except Exception:
            # Best effort: a failed window just leaves a hole in the record.
            continue

    if not pieces:
        print(f"[WL] Station {station} returned 0 rows between {begin_date} and {end_date}.")
        return pd.DataFrame(columns=["time","wl"])
    merged = pd.concat(pieces, ignore_index=True).sort_values("time").drop_duplicates("time")
    print(f"[WL] Station {station} rows: {len(merged):,}  range: {merged['time'].min()} → {merged['time'].max()}")
    return merged.reset_index(drop=True)

# Water-level source: either fetch it here, or reuse a df_wl6 the user defined earlier in the
# session; if none exists, start from an empty frame with the expected columns.
if WL_USE_FETCH:
    df_wl6 = fetch_noaa_wl(WL_BEGIN_DATE, WL_END_DATE, NOAA_WL_STATION, days_per_chunk=31)
else:
    try:
        df_wl6  # bare name lookup: raises NameError when the user has not supplied it
    except NameError:
        df_wl6 = pd.DataFrame(columns=["time","wl"])

def bandpass_wl_var(df_wl:pd.DataFrame,
                    dt_minutes:int,
                    var_win_hours:float) -> pd.DataFrame:
    """Band-pass the water level and return its centered rolling variance.

    The series is placed on a regular `dt_minutes` grid, interior gaps are linearly
    interpolated in time, and any remaining (edge) gaps are filled from the nearest valid
    samples so the filter sees a complete array. A zero-phase Butterworth band-pass of order
    BP_ORDER keeps periods between BANDPASS_MIN_MIN and BANDPASS_MAX_MIN minutes; the
    variance is then taken over a centered window of `var_win_hours` (at least 3 samples),
    requiring 80% of the window (at least 2 samples).

    NOTE(review): interior gaps of any length are interpolated before filtering, so long
    outages produce near-zero variance rather than NaN — confirm whether a gap limit
    (e.g. MAX_GAP_HOURS) should be applied here.

    Args:
        df_wl: frame with "time" and "wl" columns.
        dt_minutes: grid step in minutes.
        var_win_hours: rolling-variance window length in hours.

    Returns:
        DataFrame with columns ["time", "var1h"] (float32) on the regular grid; empty (with
        those columns) when the input is empty or has no finite water level.
    """
    from scipy.signal import butter, filtfilt

    if df_wl.empty:
        return pd.DataFrame(columns=["time","var1h"])

    step = f"{dt_minutes}min"
    t0, t1 = df_wl["time"].min(), df_wl["time"].max()
    grid = pd.date_range(t0.floor(step), t1.ceil(step), freq=step)
    s = df_wl.set_index("time")["wl"].reindex(grid)
    s = s.interpolate(method="time", limit_area="inside")

    # Normalised band edges (fraction of Nyquist), clamped to the open interval butter() accepts.
    dt_sec = dt_minutes * 60.0
    fs = 1.0 / dt_sec
    f_lo = 1.0 / (BANDPASS_MAX_MIN * 60.0)   # e.g., 30 min
    f_hi = 1.0 / (BANDPASS_MIN_MIN * 60.0)   # e.g., 20 min
    wn = [max(1e-6, min(0.999, f / (fs/2.0))) for f in (f_lo, f_hi)]
    b, a = butter(BP_ORDER, wn, btype="bandpass")

    # Explicit copy: the array is patched in place below, and to_numpy() may hand back a
    # read-only view of the Series data.
    x = s.to_numpy(dtype=float, copy=True)
    good = np.isfinite(x)
    if not good.any():
        return pd.DataFrame(columns=["time","var1h"])
    if not good.all():
        idx = np.arange(x.size)
        x[~good] = np.interp(idx[~good], idx[good], x[good])

    x_bp = filtfilt(b, a, x, method="gust")

    win = max(3, int(round((var_win_hours * 60.0) / dt_minutes)))
    minp = max(2, int(0.8 * win))
    var_series = pd.Series(x_bp, index=grid).rolling(window=win, center=True, min_periods=minp).var()
    out = pd.DataFrame({"time": grid, "var1h": var_series.astype("float32")}).dropna(subset=["time"]).reset_index(drop=True)
    return out

# Response series: band-passed water-level variance. Everything downstream needs it,
# so stop immediately with a clear message if it could not be produced.
df_var = bandpass_wl_var(df_wl6, DT_MINUTES, VAR_WIN_HOURS)
if df_var.empty:
    raise RuntimeError("var1h is empty — WL series could not be fetched/parsed or contained no finite values.")

# -------------------
# Build common grid; interpolate inputs to that grid
# -------------------
def interp_timegrid(t_src:pd.Series, x_src:np.ndarray, t_grid:pd.DatetimeIndex)->np.ndarray:
    """Linearly interpolate (t_src, x_src) in time onto t_grid.

    Source and grid timestamps are merged so interpolation is weighted by true elapsed
    time; only gaps between valid samples are filled (no extrapolation), so grid points
    before the first or after the last valid sample come back as NaN.
    """
    source = pd.Series(x_src, index=pd.to_datetime(t_src))
    all_times = pd.to_datetime(sorted(set(source.index) | set(t_grid)))
    filled = source.reindex(all_times).interpolate(method="time", limit_area="inside")
    return filled.reindex(pd.to_datetime(t_grid)).to_numpy(dtype=float)

def interp_dir_circular(t_src, dir_deg, t_grid):
    """Interpolate directions on a circle via sin/cos then atan2; returns 0–360 deg.

    The sine and cosine of each source direction are interpolated linearly in time onto
    t_grid (gaps between valid samples only, no extrapolation), renormalised to unit length
    where the interpolated vector is non-zero, and converted back to an angle. Grid points
    outside the span of valid source samples come back as NaN.
    """
    t_src = pd.to_datetime(t_src)
    theta = np.deg2rad(np.asarray(dir_deg, dtype=float))

    def _component_on_grid(values):
        """Time-interpolate one vector component onto t_grid and return a writable array."""
        ser = pd.Series(values, index=t_src).sort_index()
        union_idx = pd.to_datetime(sorted(set(ser.index).union(set(t_grid))))
        ser = ser.reindex(union_idx).interpolate("time", limit_area="inside")
        # Explicit copy: the result is normalised in place below, and to_numpy() may
        # otherwise return a read-only view.
        return ser.reindex(t_grid).to_numpy(dtype=float, copy=True)

    s_i = _component_on_grid(np.sin(theta))
    c_i = _component_on_grid(np.cos(theta))

    r = np.hypot(c_i, s_i)
    m = r > 0  # skip zero-length and NaN vectors
    c_i[m] /= r[m]
    s_i[m] /= r[m]
    return (np.degrees(np.arctan2(s_i, c_i)) % 360.0).astype(float)

# Overlap window across response & inputs
# Collect the first/last timestamp of every non-empty source; the common grid spans the
# latest start to the earliest end, i.e. the period covered by all of them.
t0_candidates = [df_var["time"].min()]
t1_candidates = [df_var["time"].max()]
if len(df_swden):
    t0_candidates.append(df_swden["time"].min()); t1_candidates.append(df_swden["time"].max())
if len(df_stdmet):
    t0_candidates.append(df_stdmet["time"].min()); t1_candidates.append(df_stdmet["time"].max())
if len(df_wind_filled):
    t0_candidates.append(df_wind_filled["time"].min()); t1_candidates.append(df_wind_filled["time"].max())
if len(df_hmsc_hour):
    t0_candidates.append(df_hmsc_hour["time"].min()); t1_candidates.append(df_hmsc_hour["time"].max())
if len(df_baro_var):
    t0_candidates.append(df_baro_var["time"].min()); t1_candidates.append(df_baro_var["time"].max())

t0 = max(t0_candidates); t1 = min(t1_candidates)
t_grid = pd.date_range(t0, t1, freq=f"{DT_MINUTES}min")

# seq_mat maps feature name -> 1-D array aligned with t_grid.
seq_mat = {}
# Response on grid
var_on_grid = interp_timegrid(df_var["time"], df_var["var1h"].values, t_grid)

# SWDEN bands
# Each band accepts either the current column name or an alternative "Hm0_*" name;
# a band with neither column becomes an all-NaN array.
if len(df_swden):
    if "Hswell" in df_swden.columns:
        seq_mat["Hswell"] = interp_timegrid(df_swden["time"], df_swden["Hswell"].values, t_grid)
    elif "Hm0_swell" in df_swden.columns:
        seq_mat["Hswell"] = interp_timegrid(df_swden["time"], df_swden["Hm0_swell"].values, t_grid)
    else:
        seq_mat["Hswell"] = np.full(len(t_grid), np.nan)
    if "Hsea" in df_swden.columns:
        seq_mat["Hsea"] = interp_timegrid(df_swden["time"], df_swden["Hsea"].values, t_grid)
    elif "Hm0_sea" in df_swden.columns:
        seq_mat["Hsea"] = interp_timegrid(df_swden["time"], df_swden["Hm0_sea"].values, t_grid)
    else:
        seq_mat["Hsea"] = np.full(len(t_grid), np.nan)
    # NOTE: unlike the two bands above, high_HIG gets no NaN placeholder when its column is absent.
    if "high_HIG" in df_swden.columns:
        seq_mat["high_HIG"] = interp_timegrid(df_swden["time"], df_swden["high_HIG"].values, t_grid)
else:
    seq_mat["Hswell"]   = np.full(len(t_grid), np.nan)
    seq_mat["Hsea"]     = np.full(len(t_grid), np.nan)
    seq_mat["high_HIG"] = np.full(len(t_grid), np.nan)

# STDMET waves
# NOTE(review): MWD is a direction but is interpolated linearly here, not circularly — confirm intent.
for c in ["Hs","DPD","APD","MWD"]:
    if len(df_stdmet):
        seq_mat[c] = interp_timegrid(df_stdmet["time"], df_stdmet[c].values, t_grid)
    else:
        seq_mat[c] = np.full(len(t_grid), np.nan)

# Wind (hourly → DT): circular interpolate direction; scale speed by WIND_WEIGHT
# No wind keys are added to seq_mat when df_wind_filled is empty.
if len(df_wind_filled):
    sp6  = interp_timegrid(df_wind_filled["time"], df_wind_filled["wind9435380_speed"].values, t_grid)
    dir6 = interp_dir_circular(df_wind_filled["time"], df_wind_filled["wind9435380_dir"].values, t_grid)
    seq_mat["wind9435380_speed"] = (sp6 * float(WIND_WEIGHT)).astype(float)
    seq_mat["wind9435380_dir"]   = dir6.astype(float)

# HMSC baro + variance
if len(df_hmsc_hour) and "hmsc_baro_hpa" in df_hmsc_hour.columns:
    seq_mat["hmsc_baro_hpa"] = interp_timegrid(df_hmsc_hour["time"], df_hmsc_hour["hmsc_baro_hpa"].values, t_grid)
else:
    seq_mat["hmsc_baro_hpa"] = np.full(len(t_grid), np.nan)
if len(df_baro_var):
    seq_mat["baro_var4h"] = interp_timegrid(df_baro_var["time"], df_baro_var["baro_var4h"].values, t_grid)
else:
    seq_mat["baro_var4h"] = np.full(len(t_grid), np.nan)

# ---- Unified gap mask (optional) ----
def gap_intervals(t:pd.Series, x:np.ndarray, min_gap_hours:float)->List[Tuple[pd.Timestamp,pd.Timestamp]]:
    """Return (start, end) timestamp pairs that delimit data gaps.

    Two kinds of gap are reported:
      * a jump between consecutive timestamps larger than ``min_gap_hours``;
      * a run of consecutive NaN values in ``x`` whose first and last
        timestamps are at least ``min_gap_hours`` apart.

    Args:
        t: Sample times (anything ``pd.to_datetime`` accepts), same length as ``x``.
        x: Sample values aligned with ``t``.
        min_gap_hours: Minimum gap duration, in hours.

    Returns:
        List of ``(start, end)`` timestamp tuples, time jumps first, then NaN runs.
    """
    # Work on a positional DatetimeIndex: indexing a Series with t[j] is a
    # label lookup and breaks for any non-default index.
    t = pd.DatetimeIndex(pd.to_datetime(t))
    if len(t) == 0:
        return []

    # Gaps in the time axis itself.
    dt = np.diff(t.values).astype("timedelta64[s]").astype(np.int64)
    jumps = np.where(dt > min_gap_hours*3600)[0]
    iv = [(t[j], t[j+1]) for j in jumps]

    # Gaps expressed as NaN runs. Building the Series against t keeps the
    # length-mismatch error a caller would otherwise get.
    isn = pd.Series(np.asarray(x), index=t).isna().to_numpy()
    if isn.any():
        # np.diff on a boolean array is XOR, so the edges must be computed on
        # integers for the -1 (falling) edge to be detectable.
        edges = np.diff(np.r_[0, isn.astype(np.int8), 0])
        starts = np.where(edges == 1)[0]        # first NaN index of each run
        ends   = np.where(edges == -1)[0] - 1   # last NaN index of each run
        min_gap = pd.Timedelta(hours=min_gap_hours)
        for a,b in zip(starts, ends):
            if (t[b]-t[a]) >= min_gap:
                iv.append((t[a], t[b]))
    return iv

def merge_intervals(intervals):
    """Merge overlapping or touching (start, end) intervals.

    The input is sorted by start; an interval whose start is not later than the
    end of the previous merged interval is folded into it. Returns a list of
    ``(start, end)`` tuples passed through ``pd.to_datetime``.
    """
    if not intervals:
        return []
    ordered = sorted(intervals, key=lambda pair: pair[0])
    merged = [list(ordered[0])]
    for begin, finish in ordered[1:]:
        current = merged[-1]
        if begin <= current[1]:
            # Overlap (or contact): extend the running interval.
            current[1] = max(current[1], finish)
        else:
            merged.append([begin, finish])
    return [(pd.to_datetime(lo), pd.to_datetime(hi)) for lo, hi in merged]

# Optional unified gap mask: collect gap intervals from every driver source,
# merge them, and flag all grid times that fall inside any merged interval.
# With USE_UNIFIED_GAP_MASK false the mask is all-False (nothing is dropped).
if USE_UNIFIED_GAP_MASK:
    all_intervals=[]
    if len(df_swden):
        # NOTE(review): only these three names are checked here; columns named
        # "Hm0_swell"/"Hm0_sea" would not contribute to the mask — confirm intent.
        for c in ["Hswell","Hsea","high_HIG"]:
            if c in df_swden.columns:
                all_intervals += gap_intervals(df_swden["time"], df_swden[c].values, MAX_GAP_HOURS)
    if len(df_stdmet):
        for c in ["Hs","DPD","APD","MWD"]:
            all_intervals += gap_intervals(df_stdmet["time"], df_stdmet[c].values, MAX_GAP_HOURS)
    if len(df_wind_filled):
        for c in ["wind9435380_speed","wind9435380_dir"]:
            all_intervals += gap_intervals(df_wind_filled["time"], df_wind_filled[c].values, MAX_GAP_HOURS)
    if len(df_hmsc_hour):
        for c in ["hmsc_baro_hpa"]:
            all_intervals += gap_intervals(df_hmsc_hour["time"], df_hmsc_hour[c].values, MAX_GAP_HOURS)
    merged_gaps = merge_intervals(all_intervals)
    gap_mask = np.zeros(len(t_grid), dtype=bool)
    for s,e in merged_gaps:
        # Interval bounds are inclusive on both sides.
        gap_mask |= (t_grid>=s) & (t_grid<=e)
else:
    gap_mask = np.zeros(len(t_grid), dtype=bool)

# Assemble df_all: target (var1h) plus every aligned driver channel, then drop
# the masked rows. Dropping rows makes the time axis non-contiguous wherever
# the mask was set.
df_all = pd.DataFrame({"time": t_grid, "var1h": var_on_grid})
for c,v in seq_mat.items():
    df_all[c] = v
df_all = df_all.loc[~gap_mask].reset_index(drop=True)

# Clean/Interpolate some inputs post-align: values above the threshold are
# blanked to NaN, then NaNs are filled by time-weighted interpolation.
# limit_area="inside" fills only NaNs surrounded by valid values, so leading
# and trailing NaNs are left as-is.
for col, thresh in [("Hs",50.0),("DPD",50.0),("APD",50.0)]:
    if col in df_all.columns:
        m = df_all[col] > thresh
        if m.any(): df_all.loc[m, col] = np.nan
        df_all[col] = pd.Series(df_all[col].values, index=df_all["time"]).interpolate("time", limit_area="inside").values

# Feature placeholder: all-NaN column, created here and not filled in this section.
df_all["wl_gradient"] = np.nan

# -------------------
# Lag candidate helpers
# -------------------
def is_eq_base(name: str) -> bool:
    """Return True when ``name`` (case-insensitive) starts with "eq" or contains "eq_"."""
    lowered = name.lower()
    if lowered.startswith("eq"):
        return True
    return "eq_" in lowered

def build_lag_candidates(df: pd.DataFrame,
                         bases: List[str],
                         dt_minutes: int,
                         step_minutes: int = 30) -> Tuple[pd.DataFrame, List[str], Dict[str, str], Dict[str, int]]:
    """Create lagged copies of each base column on a fixed lag grid.

    For every base with at least one finite value, lags of 0, step, 2*step, ...
    minutes are generated up to EQ_MAX_LAG_MIN (bases matched by ``is_eq_base``)
    or NON_EQ_MAX_LAG_MIN (all others). Columns are named
    ``"<base>__lag_<minutes>min"``.

    Returns:
        (lag_df, lag_cols, base_map, lag_minutes_map), where ``lag_df`` holds
        the lagged columns row-aligned with ``df``, ``lag_cols`` lists their
        names in creation order, ``base_map`` maps lag column -> base name and
        ``lag_minutes_map`` maps lag column -> lag in minutes.
    """
    lag_cols = []
    columns = {}
    base_map = {}
    lag_minutes_map = {}
    for base in bases:
        values = df[base].astype("float32").to_numpy()
        # A base without a single finite sample cannot yield a usable lag.
        if not np.isfinite(values).any():
            continue
        limit_min = EQ_MAX_LAG_MIN if is_eq_base(base) else NON_EQ_MAX_LAG_MIN
        for lag_min in range(0, limit_min + 1, int(step_minutes)):
            n_shift = int(round(lag_min / dt_minutes))
            col = f"{base}__lag_{lag_min}min"
            lag_minutes_map[col] = lag_min
            base_map[col] = base
            if n_shift == 0:
                columns[col] = values
            else:
                # Shift forward in rows: row i receives the value from row i - n_shift.
                columns[col] = pd.Series(values).shift(n_shift).to_numpy(dtype="float32")
            lag_cols.append(col)
    lag_df = pd.DataFrame(columns).reindex(range(len(df))).reset_index(drop=True)
    return lag_df, lag_cols, base_map, lag_minutes_map

# Build lags from df_all
# Columns that must never become lag candidates. "Hs" is part of this set in
# addition to the time/target/raw-wind style columns.
exclude_cols = {"time", "var1h", "var1h_log1p", "wsp", "wdir", "gust", "atmvar1_boost", "Hs"}
cand_bases = [c for c in df_all.columns if (c not in exclude_cols and df_all[c].dtype != object)]
lag_df, lag_cols, base_map, lag_minutes_map = build_lag_candidates(df_all, cand_bases, DT_MINUTES)

# -------------------
# XGBoost ranking: pick ONE best lag per base
# -------------------
from xgboost import XGBRegressor

# Target = var1h shifted PRIMARY_HORIZON into the future (in grid steps).
h_steps = max(1, int(round(PRIMARY_HORIZON / DT_MINUTES)))
tgt = df_all["var1h"].shift(-h_steps).reset_index(drop=True).astype("float32")
# Only rows where every lag column and the target are present are used for ranking.
rank_data = pd.concat([lag_df, tgt.rename("target")], axis=1).dropna().reset_index(drop=True)

if rank_data.empty:
    raise RuntimeError("No data for XGB ranking after lagging. Check df_all coverage and PRIMARY_HORIZON.")

X_rank = rank_data[lag_cols].to_numpy(dtype=np.float32)
y_rank = rank_data["target"].to_numpy(dtype=np.float32)

xgb_imp = XGBRegressor(
    n_estimators=min(300, XGB_N_ESTIMATORS//2), learning_rate=XGB_LEARNING_RATE,
    max_depth=XGB_MAX_DEPTH, subsample=XGB_SUBSAMPLE, colsample_bytree=XGB_COLSAMPLE,
    reg_lambda=XGB_L2, objective="reg:squarederror", n_jobs=-1, random_state=42, verbosity=XGB_VERBOSE
)
# Importances are fitted on the first 80% of rows only.
split = max(1, int(0.8 * len(X_rank)))
xgb_imp.fit(X_rank[:split], y_rank[:split], verbose=False)
importances = dict(zip(lag_cols, xgb_imp.feature_importances_))

# For every base keep the single lag column with the highest importance.
best_lag_per_base = {}
for nm, imp in importances.items():
    b = base_map[nm]
    prev = best_lag_per_base.get(b)
    if (prev is None) or (imp > prev[1]):
        best_lag_per_base[b] = (nm, float(imp))

# Rank bases by their best lag's importance and keep the top TOP_LAG_FEATURES.
best_list = [(b, nm, imp) for b, (nm, imp) in best_lag_per_base.items()]
best_list.sort(key=lambda t: t[2], reverse=True)
best_list = best_list[:TOP_LAG_FEATURES]
selected_lag_cols = [nm for _, nm, _ in best_list]
selected_bases = [b for b, _, _ in best_list]

print("\nSelected ONE best lag per base:")
for b, nm, imp in best_list:
    print(f"  {nm:<36s} (base={b}, importance={imp:.6f}, lag={lag_minutes_map[nm]} min)")

# Table used for windowing: time, target and the selected lag channels.
df_lag_all = pd.concat(
    [df_all[["time", "var1h"]].reset_index(drop=True),
     lag_df[selected_lag_cols].reset_index(drop=True)],
    axis=1
)

def _best_existing_lag_for(base_name:str, candidate_lags_min:list[int])->str|None:
    """Pick a lag column for ``base_name``.

    The first lag in ``candidate_lags_min`` whose column exists in ``lag_cols``
    wins. If none exists, the lag column of that base with the highest
    ``importances`` value is returned, or ``None`` when the base has no lag
    columns at all.
    """
    preferred = (f"{base_name}__lag_{minutes}min" for minutes in candidate_lags_min)
    hit = next((col for col in preferred if col in lag_cols), None)
    if hit is not None:
        return hit
    # No preferred lag available: fall back to the most important one.
    same_base = [col for col in lag_cols if base_map.get(col) == base_name]
    if same_base:
        return max(same_base, key=lambda col: importances.get(col, 0.0))
    return None

# Force barometric channels into the selected feature set when requested.
if BARO_FORCE_IN_XGB:
    for base in ["baro_var4h", "hmsc_baro_hpa"]:
        nm = _best_existing_lag_for(base, BARO_FORCE_LAGS_MIN)
        if nm and nm not in selected_lag_cols:
            selected_lag_cols.append(nm)
            # The base may already be listed with a different lag; avoid duplicates.
            if base not in selected_bases:
                selected_bases.append(base)

    # df_lag_all was assembled before the forced columns were appended. Rebuild
    # it so every name in selected_lag_cols exists in the table; otherwise the
    # later column selection by selected_lag_cols raises KeyError.
    df_lag_all = pd.concat(
        [df_all[["time", "var1h"]].reset_index(drop=True),
         lag_df[selected_lag_cols].reset_index(drop=True)],
        axis=1
    )


# ============================================================
# Mini-TFT (seq-to-one) — compact training loop with nonneg output
# ============================================================
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Detect CUDA once; use_gpu is read later when choosing the torch device.
use_gpu = torch.cuda.is_available()
if use_gpu:
    # Let cuDNN benchmark and pick convolution/RNN algorithms.
    torch.backends.cudnn.benchmark = True
    # Guarded with hasattr because the function may be missing from the installed torch.
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")

def _amp_ctx():
    return nullcontext()

class GLU(nn.Module):
    """Gated linear unit: ``sigmoid(g(x)) * v(x)`` with two ``d -> d`` linear maps."""

    def __init__(self, d):
        super().__init__()
        # Creation order (v, then g) is kept so parameter initialisation is unchanged.
        self.v = nn.Linear(d, d)
        self.g = nn.Linear(d, d)

    def forward(self, x):
        gate = torch.sigmoid(self.g(x))
        value = self.v(x)
        return gate * value

class MiniTFT(nn.Module):
    """Compact sequence-to-one network: per-feature embeddings, softmax feature
    weighting, LSTM encoder, single-query attention over the LSTM outputs and a
    small gated projection head producing one value per window.

    When ``nonneg`` is set the output passes through Softplus and, if
    ``var_floor`` is positive, is clamped from below at ``var_floor``.
    The most recent feature weights are kept (detached) in ``last_alpha``.
    """

    def __init__(self, num_features:int, d_model:int=TFT_D_MODEL, nhead:int=TFT_NHEAD,
                 lstm_layers:int=TFT_LSTM_LAYERS, ff_dim:int=TFT_FF_DIM, dropout:float=TFT_DROPOUT,
                 nonneg:bool=TFT_NONNEG_OUTPUT, var_floor:float=TFT_VAR_FLOOR):
        super().__init__()
        self.F = num_features
        self.d = d_model
        self.nonneg = nonneg
        self.var_floor = float(var_floor)

        # One scalar -> d_model embedding per input feature.
        self.feat_emb = nn.ModuleList([nn.Linear(1, d_model) for _ in range(num_features)])
        # Produces per-timestep logits over features (softmaxed in forward).
        self.vsn_w = nn.Linear(num_features, num_features)
        self.lstm = nn.LSTM(input_size=d_model, hidden_size=d_model,
                            num_layers=lstm_layers, batch_first=True, dropout=0.0, bidirectional=False)
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)

        head_layers = [
            nn.Linear(2*d_model, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
            GLU(d_model),
            nn.Linear(d_model, 1),
        ]
        self.proj = nn.Sequential(*head_layers)
        self.out_act = nn.Softplus() if self.nonneg else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.last_alpha = None

    def forward(self, x):  # (B,T,F)
        _b, _t, _f = x.shape                    # requires a rank-3 input
        per_feature = x.unsqueeze(-1)           # (B,T,F,1)
        stacked = torch.stack(
            [emb(per_feature[:, :, idx, :]) for idx, emb in enumerate(self.feat_emb)],
            dim=2,
        )                                       # (B,T,F,d)
        weights = torch.softmax(self.vsn_w(x), dim=2).unsqueeze(-1)   # (B,T,F,1)
        self.last_alpha = weights.detach()
        mixed = self.dropout((weights * stacked).sum(dim=2))          # (B,T,d)
        hidden, _ = self.lstm(mixed)
        query = hidden[:, -1:, :]               # last timestep attends over the whole sequence
        context, _ = self.attn(query=query, key=hidden, value=hidden, need_weights=False)
        features = torch.cat([query, context], dim=-1).squeeze(1)
        y = self.out_act(self.proj(features).squeeze(-1))
        if self.nonneg and self.var_floor > 0:
            y = torch.clamp(y, min=self.var_floor)
        return y

def _safe_robust_params(X: np.ndarray, q_low=25.0, q_high=75.0, eps=1e-6):
    Xf = X.reshape(-1, X.shape[2])
    med = np.nanmedian(Xf, axis=0)
    q1  = np.nanpercentile(Xf, q_low, axis=0)
    q3  = np.nanpercentile(Xf, q_high, axis=0)
    iqr = q3 - q1
    scale = np.where(iqr <= eps, 1.0, iqr)
    return med.astype(np.float32), scale.astype(np.float32)

def fit_scaler_from_windows_safe(X_tr, X_va=None, X_te=None):
    """Fit a robust per-feature scaler on the first non-empty split.

    Splits are tried in the order train, validation, test. If all of them are
    empty or ``None``, an identity scaler (centre 0, scale 1) is returned whose
    width matches the feature axis of the first available 3-D array, or 1 when
    no array is available.

    Returns:
        dict with float32 arrays under ``"center"`` and ``"scale"``.
    """
    candidates = (X_tr, X_va, X_te)
    for X in candidates:
        if X is not None and len(X) > 0:
            med, scale = _safe_robust_params(X)
            return {"center": med, "scale": scale}
    # Every split is empty here, so the feature count has to come from the
    # array shape (an empty array still carries its feature axis), not from size.
    F = next(
        (X.shape[2] for X in candidates if X is not None and getattr(X, "ndim", 0) == 3),
        1,
    )
    return {"center": np.zeros(F, np.float32), "scale": np.ones(F, np.float32)}

def norm_windows(X, rs, clip=20.0):
    """Apply a fitted scaler to (N, T, F) windows.

    Computes ``(X - center) / scale`` per feature, replaces NaN and +/-inf by
    0, casts to float32 and, unless ``clip`` is None, clips to ``[-clip, clip]``.
    An empty ``X`` is returned unchanged.
    """
    if len(X) == 0:
        return X
    n_feat = X.shape[2]
    center = rs["center"].reshape(1, 1, n_feat)
    scale = rs["scale"].reshape(1, 1, n_feat)
    scaled = np.nan_to_num((X - center) / scale, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    if clip is None:
        return scaled
    return np.clip(scaled, -clip, clip, out=scaled)

def make_windows_from_lagged(df:pd.DataFrame, lag_cols:List[str],
                             history_hours:int, horizon_min:int, dt_minutes:int,
                             stride:int=1):
    """Cut (history window -> future target) samples out of a lagged table.

    Rows with NaN in any lag column or in ``var1h`` are dropped first. A sample
    is then kept only if its history rows and its target row are contiguous in
    time, i.e. the target lies exactly ``(steps_hist - 1 + steps_ahead) *
    dt_minutes`` minutes after the first history row. Without this check,
    windows would silently straddle the holes left by the dropped rows and pair
    inputs with a target that is not ``horizon_min`` ahead.
    (Assumes ``time`` is increasing on a regular ``dt_minutes`` grid.)

    Args:
        df: Table with ``time``, ``var1h`` and every name in ``lag_cols``.
        lag_cols: Feature columns, in the order they appear on the feature axis.
        history_hours: Length of the input window in hours.
        horizon_min: Forecast lead in minutes, counted from the last history row.
        dt_minutes: Grid step in minutes.
        stride: Step (in rows) between consecutive window end positions.

    Returns:
        ``(X, y, t)``: float32 array ``(N, steps_hist, F)``, float32 targets
        ``(N,)`` and a ``DatetimeIndex`` with the target timestamps.
    """
    steps_hist  = int(round(history_hours*60/dt_minutes))
    steps_ahead = int(round(horizon_min/dt_minutes))
    lag_cols = list(lag_cols)
    F = len(lag_cols)

    def _empty():
        # Same dtypes as the non-empty result so callers can treat both alike.
        return (np.empty((0, steps_hist, F), dtype=np.float32),
                np.empty((0,), dtype=np.float32),
                pd.DatetimeIndex([]))

    tab = df[["time","var1h"] + lag_cols].dropna(subset=lag_cols + ["var1h"]).reset_index(drop=True)
    n_rows = len(tab)
    # One window needs steps_hist history rows plus steps_ahead rows to reach the target.
    if n_rows < steps_hist + steps_ahead:
        return _empty()

    # Pull everything into arrays once instead of slicing the frame per window.
    times  = pd.DatetimeIndex(pd.to_datetime(tab["time"]))
    feats  = tab[lag_cols].to_numpy(dtype=np.float32)
    target = tab["var1h"].to_numpy(dtype=np.float64)

    ends    = np.arange(steps_hist, n_rows - steps_ahead + 1, stride)   # exclusive window ends
    starts  = ends - steps_hist
    tgt_idx = ends - 1 + steps_ahead
    if len(ends) == 0:
        return _empty()

    # Contiguity: the span from first history row to target row must match the grid.
    span = pd.Timedelta(minutes=dt_minutes) * (steps_hist - 1 + steps_ahead)
    keep = np.asarray((times[tgt_idx] - times[starts]) == span)
    keep &= np.isfinite(target[tgt_idx])
    starts, tgt_idx = starts[keep], tgt_idx[keep]
    if len(starts) == 0:
        return _empty()

    # Gather all windows at once: (N, steps_hist) row indices -> (N, steps_hist, F).
    X = feats[starts[:, None] + np.arange(steps_hist)[None, :]]
    finite = np.isfinite(X).all(axis=(1, 2))   # rejects +/-inf, which dropna does not remove
    if not finite.any():
        return _empty()
    return X[finite], target[tgt_idx[finite]].astype(np.float32), times[tgt_idx[finite]]

def adaptive_index_split(X, y, t_idx, train_frac=0.70, val_frac=0.15):
    """Split aligned arrays into train/validation/test by position.

    With fewer than 10 samples everything goes to the test split and the other
    two are empty. Otherwise the train split gets at least one sample, the
    validation split at least one, and at least one sample is left for test.
    Each split is returned as an ``(X, y, t_idx)`` triple.
    """
    def _take(sl):
        return X[sl], y[sl], t_idx[sl]

    n = len(X)
    if n < 10:
        nothing = slice(0, 0)
        return _take(nothing), _take(nothing), (X, y, t_idx)

    train_end = max(1, int(round(train_frac * n)))
    val_end = max(train_end + 1, int(round((train_frac + val_frac) * n)))
    val_end = min(val_end, n - 1)
    return _take(slice(None, train_end)), _take(slice(train_end, val_end)), _take(slice(val_end, None))

def cap_split(X_, y_, t_, cap=MAX_SAMPLES_SPLIT):
    """Thin a split down to at most ``cap`` samples.

    Splits already within the cap are returned untouched; larger ones are
    subsampled at ``cap`` evenly spaced positions from first to last sample.
    """
    n = len(X_)
    if n > cap:
        picks = np.linspace(0, n - 1, cap, dtype=int)
        return X_[picks], y_[picks], t_[picks]
    return X_, y_, t_

def _fmt(ts):
    return ts.strftime("%Y-%m-%d %H:%M") if (ts is not None and pd.notna(ts)) else "—"

def print_split_summary(horizon, t_tr, t_va, t_te):
    """Print the time range and sample count of each split.

    Returns the three ``(first, last)`` timestamp pairs in the order train,
    validation, test; an empty split yields ``(None, None)``.
    """
    def _span(t):
        return (t.min(), t.max(), len(t)) if len(t) else (None, None, 0)

    spans = [_span(t_tr), _span(t_va), _span(t_te)]
    print(f"[{horizon}m] Split ranges:")
    for label, (first, last, count) in zip(("Train", "Valid", "Test "), spans):
        print(f"  {label}: {_fmt(first)} → {_fmt(last)}  (n={count})")
    return tuple((first, last) for first, last, _ in spans)

# ---- Train loop (per horizon) ----
# For every forecast horizon: build windows, split them, train a MiniTFT with
# early stopping, score it, store everything in `results` and plot the curves.
results = {}


def predict_tft(model, X, device):
    """Run `model` over the windows `X` in batches and return the stacked outputs."""
    ds = TensorDataset(torch.from_numpy(X))
    dl = DataLoader(ds, batch_size=max(1, min(TFT_BATCH, len(ds))), shuffle=False, drop_last=False)
    out=[]; model.eval()
    with torch.no_grad(), _amp_ctx():
        for (xb,) in dl:
            xb = xb.to(device)
            out.append(model(xb).cpu().numpy())
    return np.concatenate(out, axis=0) if len(out)>0 else np.array([])


def _rmse(a,b):
    """Root-mean-square error; NaN when the inputs are empty or differ in length."""
    return float(np.sqrt(np.mean((np.asarray(a)-np.asarray(b))**2))) if len(a)==len(b) and len(a)>0 else float("nan")


for horizon in HORIZONS_MIN:
    print(f"\n=== Horizon {horizon} min ===")
    X, y, t_idx = make_windows_from_lagged(df_lag_all, selected_lag_cols,
                                           HISTORY_HOURS, horizon, DT_MINUTES, stride=1)
    if len(t_idx):
        print(f"[{horizon}m] Windows available:", t_idx.min(), "→", t_idx.max(), f"(n={len(t_idx):,})")
    else:
        print(f"[{horizon}m] No windows after lagging/masking.")
        continue

    # Time split
    mtr = t_idx <= TRAIN_END
    mva = (t_idx > TRAIN_END) & (t_idx <= VAL_END)
    mte = t_idx > VAL_END
    X_tr, y_tr, t_tr = X[mtr], y[mtr], t_idx[mtr]
    X_va, y_va, t_va = X[mva], y[mva], t_idx[mva]
    X_te, y_te, t_te = X[mte], y[mte], t_idx[mte]
    if len(X_tr)==0 or len(X_va)==0 or len(X_te)==0:
        print("  Time split empty; using adaptive 70/15/15.")
        (X_tr, y_tr, t_tr), (X_va, y_va, t_va), (X_te, y_te, t_te) = adaptive_index_split(X, y, t_idx)
    if len(X_te)==0:
        # Last resort: 85/15 train/test with no validation split.
        n=len(X); cut=max(1,int(0.85*n))
        X_tr,y_tr,t_tr = X[:cut],y[:cut],t_idx[:cut]
        X_va,y_va,t_va = X[:0],y[:0],t_idx[:0]
        X_te,y_te,t_te = X[cut:],y[cut:],t_idx[cut:]

    X_tr, y_tr, t_tr = cap_split(X_tr, y_tr, t_tr, cap=MAX_SAMPLES_SPLIT)
    X_va, y_va, t_va = cap_split(X_va, y_va, t_va, cap=MAX_SAMPLES_SPLIT)
    X_te, y_te, t_te = cap_split(X_te, y_te, t_te, cap=MAX_SAMPLES_SPLIT)

    _ = print_split_summary(horizon, t_tr, t_va, t_te)

    # Nothing to fit on: a DataLoader cannot be built with batch_size=0.
    if len(X_tr) == 0:
        print(f"[{horizon}m] No training windows after splitting; skipping horizon.")
        continue

    # --- Target transform for TFT ---
    if TFT_TARGET_LOG1P:
        y_tr_t = np.log1p(np.maximum(y_tr, 0.0))
        y_va_t = np.log1p(np.maximum(y_va, 0.0))
        y_te_t = np.log1p(np.maximum(y_te, 0.0))
    else:
        y_tr_t, y_va_t, y_te_t = y_tr, y_va, y_te

    # Normalize windows for TFT
    rs     = fit_scaler_from_windows_safe(X_tr, X_va, X_te)
    X_tr_n = norm_windows(X_tr, rs); X_va_n = norm_windows(X_va, rs); X_te_n = norm_windows(X_te, rs)

    # --- Datasets / Loaders (use transformed targets y_*_t) ---
    device = torch.device("cuda" if use_gpu else "cpu")
    model = MiniTFT(
        num_features=X_tr_n.shape[2],
        d_model=TFT_D_MODEL, nhead=TFT_NHEAD,
        lstm_layers=TFT_LSTM_LAYERS, ff_dim=TFT_FF_DIM, dropout=TFT_DROPOUT,
        nonneg=TFT_NONNEG_OUTPUT, var_floor=TFT_VAR_FLOOR
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=TFT_LR)

    ds_tr = TensorDataset(torch.from_numpy(X_tr_n), torch.from_numpy(y_tr_t))
    ds_va = TensorDataset(torch.from_numpy(X_va_n), torch.from_numpy(y_va_t))
    has_val = len(ds_va) > 0
    # max(1, ...) keeps the loader constructible when the validation split is empty.
    dl_tr = DataLoader(ds_tr, batch_size=max(1, min(TFT_BATCH, len(ds_tr))), shuffle=True,  drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=max(1, min(TFT_BATCH, len(ds_va))), shuffle=False, drop_last=False)

    best = np.inf; best_state=None; bad=0
    train_curve=[]; val_curve=[]

    # Spike threshold for weighting computed in TRANSFORMED space
    q = np.quantile(y_tr_t, BARO_SPIKE_Q) if (WEIGHT_SPIKES and len(y_tr_t)>0) else None

    for ep in range(TFT_EPOCHS):
        model.train(); tr_loss=0.0
        for xb,yb in dl_tr:
            xb,yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            with _amp_ctx():
                pred = model(xb)                 # pred in transformed space
                l = (pred - yb) ** 2
                if q is not None:
                    # Up-weight samples whose target is at or above the spike quantile.
                    w = torch.where(
                        yb >= torch.tensor(q, device=yb.device, dtype=yb.dtype),
                        torch.tensor(BARO_SPIKE_LOSS_W, device=yb.device, dtype=yb.dtype),
                        torch.tensor(1.0,               device=yb.device, dtype=yb.dtype)
                    )
                    l = l * w
                loss = l.mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            tr_loss += loss.item() * len(xb)
        tr_loss /= max(1, len(dl_tr.dataset))

        model.eval()
        if has_val:
            va_loss=0.0
            with torch.no_grad(), _amp_ctx():
                for xb,yb in dl_va:
                    xb,yb = xb.to(device), yb.to(device)
                    out = model(xb)
                    va_loss += nn.functional.mse_loss(out, yb, reduction="sum").item()
            va_loss /= len(dl_va.dataset)
        else:
            va_loss = float("nan")
        train_curve.append(tr_loss); val_curve.append(va_loss)
        print(f"[TFT {horizon}m] epoch {ep+1:02d}  trainMSE={tr_loss:.5f}  valMSE={va_loss:.5f}")

        # Early stopping watches validation loss; with no validation split it
        # falls back to training loss (a constant 0 would freeze epoch 1 as "best").
        monitor = va_loss if has_val else tr_loss
        if monitor + 1e-6 < best:
            best = monitor; bad = 0
            # state_dict() returns references to the live parameters, so the
            # snapshot must be cloned or later epochs would overwrite it.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= TFT_PATIENCE:
                print("[TFT] early stop."); break

    if best_state is not None:
        model.load_state_dict(best_state)

    yp_tr = predict_tft(model, X_tr_n, device) if len(X_tr_n)>0 else np.array([])
    yp_va = predict_tft(model, X_va_n, device) if len(X_va_n)>0 else np.array([])
    yp_te = predict_tft(model, X_te_n, device) if len(X_te_n)>0 else np.array([])

    # --- invert transform back to variance units ---
    if TFT_TARGET_LOG1P:
        if yp_tr.size: yp_tr = np.expm1(yp_tr)
        if yp_va.size: yp_va = np.expm1(yp_va)
        if yp_te.size: yp_te = np.expm1(yp_te)

    # Safety clamp (numeric)
    if yp_tr.size: yp_tr = np.maximum(yp_tr, TFT_VAR_FLOOR)
    if yp_va.size: yp_va = np.maximum(yp_va, TFT_VAR_FLOOR)
    if yp_te.size: yp_te = np.maximum(yp_te, TFT_VAR_FLOOR)

    # Metrics are computed in original (untransformed) units.
    R2   = r2_score(y_te, yp_te) if len(yp_te)>0 and len(y_te)>0 else float("nan")
    RMSE = _rmse(y_te, yp_te)
    tr_RMSE = _rmse(y_tr, yp_tr) if len(yp_tr)>0 else float("nan")
    va_RMSE = _rmse(y_va, yp_va) if len(yp_va)>0 else float("nan")
    print(f"[TFT] {horizon}m  Train RMSE={tr_RMSE:.3f}  Val RMSE={va_RMSE:.3f}  Test R²={R2:.3f}  RMSE={RMSE:.3f}")

    # Rough fit diagnostics from the train/validation RMSE gap.
    if np.isfinite(tr_RMSE) and np.isfinite(va_RMSE):
        gap = va_RMSE - tr_RMSE
        if gap > max(0.15*va_RMSE, 0.02):
            print("  ↪ Likely OVERFITTING: try ↑TFT_DROPOUT, ↓TFT_D_MODEL/FF_DIM, fewer TFT_EPOCHS, ↑HISTORY_HOURS.")
        elif (tr_RMSE > va_RMSE*1.2):
            print("  ↪ Likely UNDERFITTING: try ↑TFT_D_MODEL/FF_DIM/LSTM_LAYERS/EPOCHS, add features, extend HISTORY_HOURS.")

    results[horizon] = dict(
        t_val=t_va, y_val=y_va, yp_val=yp_va,
        t=t_te, y=y_te, yp=yp_te, r2=R2, rmse=RMSE,
        scaler=rs, train_curve=np.array(train_curve), val_curve=np.array(val_curve),
        t_tr=t_tr, y_tr=y_tr, yp_tr=yp_tr
    )

    # Learning curves
    if len(train_curve) and len(val_curve):
        plt.figure(figsize=(8,3.2))
        plt.plot(train_curve, label="Train MSE")
        plt.plot(val_curve, label="Val MSE")
        plt.title(f"TFT learning curves — {horizon} min")
        plt.xlabel("Epoch"); plt.ylabel("MSE"); plt.grid(alpha=0.3); plt.legend(); plt.tight_layout()
        plt.show()

# ---- Per-horizon visualization: TFT vs actual + inputs during forecast window ----
def _get_base_series(name:str):
    """Return ``(time, values)`` arrays for column ``name`` of ``df_all``, or ``(None, None)`` if absent."""
    if name not in df_all.columns:
        return None, None
    return df_all["time"].to_numpy(), df_all[name].to_numpy()

# For each trained horizon: plot predictions against measurements for the
# chosen split and window, then the raw driver series over the same period.
for horizon in HORIZONS_MIN:
    rec = results.get(horizon)
    if not rec:
        # Horizon was skipped during training (no entry in results).
        continue
    # Use the test split when requested and non-empty; otherwise validation.
    if PRED_SPLIT.lower() == "test" and len(rec["t"])>0:
        t_plot, y_plot, yhat_plot = rec["t"], rec["y"], rec["yp"]
        split_name = "Test"
    else:
        t_plot, y_plot, yhat_plot = rec["t_val"], rec["y_val"], rec["yp_val"]
        split_name = "Validation"
    if len(t_plot)==0:
        continue

    # Preferred window is [PRED_START, PRED_END]; if it holds no samples,
    # fall back to the last PLOT_LAST_DAYS days of the split.
    mwin = (t_plot >= PRED_START) & (t_plot <= PRED_END)
    if not np.any(mwin):
        end = t_plot[-1]; start = end - pd.Timedelta(days=PLOT_LAST_DAYS)
        mwin = (t_plot >= start) & (t_plot <= end)
    if not np.any(mwin):
        continue

    ts, yt, ypp = t_plot[mwin], y_plot[mwin], yhat_plot[mwin]

    # Figure 1: TFT predictions vs actual
    plt.figure(figsize=(14,4))
    plt.plot(ts, yt, label="Measured var1h", lw=1.2)
    plt.plot(ts, ypp, "--", label=f"TFT Pred ({split_name})", lw=1.2)
    plt.title(f"Mini-TFT — {horizon} min ({split_name})")
    plt.ylabel("var1h"); plt.xlabel("Time"); plt.grid(alpha=0.3); plt.legend()
    plt.tight_layout(); plt.show()

    # Figure 2: Input drivers (top K bases)
    # These are the un-lagged base series from df_all, not the lagged channels.
    bases_to_plot = selected_bases[:PLOT_TOP_INPUTS]
    nK = len(bases_to_plot)
    if nK > 0:
        fig, axs = plt.subplots(nK, 1, figsize=(14, 1.8*nK), sharex=True)
        # plt.subplots returns a bare Axes (not an array) for a single row.
        if nK == 1: axs = [axs]
        for ax, bn in zip(axs, bases_to_plot):
            tb, xb = _get_base_series(bn)
            if tb is None:
                ax.text(0.02, 0.6, f"{bn} (no series)", transform=ax.transAxes); continue
            tb = pd.to_datetime(tb)
            # Restrict the driver to the time span covered by the prediction plot.
            m = (tb >= ts.min()) & (tb <= ts.max())
            ax.plot(tb[m], np.asarray(xb)[m], lw=1.0)
            ax.set_ylabel(bn); ax.grid(alpha=0.25)
        axs[-1].set_xlabel("Time")
        fig.suptitle(f"Driver inputs during forecast window — horizon {horizon} min", y=0.98)
        plt.tight_layout(rect=[0,0,1,0.96]); plt.show()

# Final listing of the lag channels in selected_lag_cols.
print("\nFinal selected lag channels (one per base):")
for nm in selected_lag_cols:
    print("  ", nm)


⚠️ Patch failed: 'ipynb'


STDMET years:   0%|          | 0/18 [00:00<?, ?yr/s]

KeyboardInterrupt: 

In [2]:
# TEMP (Colab): ensure metadata.widgets.state exists so GitHub can render
# Run once before "Save a copy in GitHub…", then comment out.
import json, os, time
patched_path = "/content/notebook_patched.ipynb"

def _ensure_state(nb):
    nb.setdefault("metadata", {}).setdefault("widgets", {}).setdefault("state", {})
    return nb

# Patch the notebook so metadata.widgets.state exists, preferring the live
# Colab document and falling back to a file on disk.
nb = None  # live notebook JSON once (and if) Colab returns it
try:
    # 1) Ask Colab for the live notebook JSON
    from google.colab import _message, notebook as colab_notebook
    # Force-save first so frontend state is current
    try:
        colab_notebook.save("")
        time.sleep(0.5)
    except Exception:
        # Deliberately best-effort: a failed save must not abort the patch.
        pass

    resp = _message.blocking_request("get_ipynb", {})  # typical key: 'ipynb'
    nb = resp.get("ipynb") or resp.get("notebook") or resp  # fallback for edge cases
    if not isinstance(nb, dict):
        raise RuntimeError("Colab did not return a notebook dict")

    _ensure_state(nb)

    # 2) Try to push patched JSON back into the running doc (so the GitHub saver uses it)
    _message.blocking_request("set_ipynb", {"ipynb": nb})
    print("✅ Patched in-memory notebook metadata. Now do: File → Save a copy in GitHub…")

    # 3) Also write a patched copy to disk as a fallback
    # Explicit UTF-8: ensure_ascii=False writes non-ASCII characters verbatim.
    with open(patched_path, "w", encoding="utf-8") as f:
        json.dump(nb, f, ensure_ascii=False)
    print(f"💾 Fallback copy written to {patched_path}")

except Exception as e:
    # If Colab API is locked down, at least give you a patched file to upload manually
    print("⚠️ Could not patch in-memory notebook. Writing a patched local copy instead.")
    print(f"   Reason: {type(e).__name__}: {e}")
    try:
        if isinstance(nb, dict):
            # The live notebook was read before the failure: keep its cells.
            _ensure_state(nb)
        else:
            # Minimal skeleton with just the metadata fix if we can't read live JSON
            # (still preserves GitHub rendering; cells won't be included in this extreme fallback)
            nb = {"cells": [], "metadata": {"widgets": {"state": {}}}, "nbformat": 4, "nbformat_minor": 5}
        with open(patched_path, "w", encoding="utf-8") as f:
            json.dump(nb, f, ensure_ascii=False)
        print(f"💾 Wrote {patched_path}. If needed, upload this file to GitHub manually.")
    except Exception as ee:
        print("❌ Patch failed:", ee)


⚠️ Could not patch in-memory notebook. Writing a patched local copy instead.
💾 Wrote /content/notebook_patched.ipynb. If needed, upload this file to GitHub manually.
