# Air Quality Cross-Station Transfer Learning Using a Geo-Aware Approach

Protocol:
- Train: first **365 days**
- Val: next **30 days**
- Test: remainder
- Reduced targets: train only on the **last k months of the 1-year training interval**, swept over k = 1…12 (val/test unchanged)

Methods (the Hybrid_TL variants correspond to the `GeoAware_TL` method name used in code):
- Scratch_Univar
- TL_Univar
- Hybrid_TL_Univar
- Scratch_Multivar
- TL_Multivar
- Hybrid_TL_Multivar

In [None]:
import os, glob, time, json, re, random, warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from tqdm.auto import tqdm

import mlflow
from dotenv import load_dotenv

# Silence third-party warning noise (library deprecations etc.) ...
warnings.filterwarnings("ignore")
# ...but keep the UserWarnings this notebook raises itself (missing geo
# coordinates, skipped target months); otherwise those diagnostics are lost.
# Later filters take precedence. Assumes this code runs as module "__main__"
# (true for a notebook kernel or `python file.py`).
warnings.filterwarnings("default", category=UserWarning, module="__main__")

# -----------------------
# Env + MLflow
# -----------------------
load_dotenv()  # load settings/credentials from a local .env file, if present
mlflow.set_tracking_uri("https://mlflow.stack.grega.xyz/")
EXPERIMENT_NAME = "AirQuality_CrossStation_TrainMonthsSweep_TL_Torch"
mlflow.set_experiment(EXPERIMENT_NAME)

# -----------------------
# Repro
# -----------------------
SEED = 42
# NOTE: setting PYTHONHASHSEED at runtime does not change this interpreter's
# str-hash randomization; it only affects child processes started later.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# cuBLAS needs a fixed workspace config for deterministic GEMMs on CUDA
# (used by the LSTMs); without it use_deterministic_algorithms(True,
# warn_only=True) can only warn. Must be set before the first cuBLAS call.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -----------------------
# Config
# -----------------------
DATA_DIR = "data/processed/combined"   # one CSV per station; the file stem is used as the station name
GEO_JSON_PATH = "data/processed/stations_geo.json"   # change if needed

TRAIN_DAYS = 365   # training span, measured from each station's first timestamp
VAL_DAYS = 30      # validation span directly after training; everything later is test
TRAIN_MONTHS_SWEEP = list(range(1, 13))   # k values: train on only the last k months of the training span

SEQ_LEN = 24   # input window length in rows (presumably 24 h for hourly data — confirm)
HORIZON = 1    # target is the row HORIZON steps after the end of the input window

BATCH_SIZE = 64
MAX_EPOCHS = 200
LR = 1e-3
WEIGHT_DECAY = 0.0

EARLY_STOP_PATIENCE = 10   # epochs without validation improvement before stopping
REDUCE_LR_PATIENCE = 5     # ReduceLROnPlateau patience (schedulers are built with factor=0.5)

LSTM_UNITS = 48   # hidden size of each of the two LSTM layers
EMBED_DIM = 16    # station-embedding / geo-projection width in HybridLocLSTM
GEO_DIM = 2       # [lat, lon]; must match the length of the vectors geo_vec() returns

RESULTS_DIR = "exp_results"   # local scratch dir for files before they are logged to MLflow
os.makedirs(RESULTS_DIR, exist_ok=True)

# Candidate meteo columns; only those present in every station CSV are used as features.
DEFAULT_METEO_NUM_COLS = ["temperature", "rain", "pressure", "precipitation", "wind_speed"]
# Cyclic calendar encodings created by add_calendar_features().
CAL_COLS = ["hour_sin", "hour_cos", "dow_sin", "dow_cos", "doy_sin", "doy_cos"]


# ============================================================
# Helpers: geo loader
# ============================================================
def load_geo_map(path: str) -> Dict[str, Dict[str, float]]:
    """Load station coordinates from a JSON (or JSON-like) file.

    The file must contain a list of objects with ``serial``, ``latitude`` and
    ``longitude`` keys; ``label`` is optional and defaults to ``serial``.
    If strict JSON parsing fails, a best-effort repair is attempted for
    JavaScript-style literals. Unquoted keys are quoted, single quotes become
    double quotes, and trailing commas before ``}``/``]`` are dropped.

    Args:
        path: Path to the geo file.

    Returns:
        Mapping ``str(serial) -> {"lat": float, "lon": float, "label": str}``.
        Entries that are not objects, or that lack a serial or a coordinate,
        are skipped.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the top-level value is not a list, or no valid station
            remains (``json.JSONDecodeError``, a ``ValueError`` subclass, is
            raised when the text cannot be parsed even after repair).
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Geo file not found: {path}")
    # Context manager so the file handle is always closed.
    with open(path, "r", encoding="utf-8") as fh:
        txt = fh.read().strip()
    try:
        geo_list = json.loads(txt)
    except json.JSONDecodeError:
        # Heuristic repair; it may corrupt string values containing
        # apostrophes or "word:" sequences.
        txt2 = re.sub(r'([\{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*):', r'\1"\2"\3:', txt)
        txt2 = re.sub(r"'", '"', txt2)
        txt2 = re.sub(r",\s*([}\]])", r"\1", txt2)
        geo_list = json.loads(txt2)

    if not isinstance(geo_list, list):
        raise ValueError(
            f"Geo file must contain a list of stations, got {type(geo_list).__name__}"
        )

    geo = {}
    for s in geo_list:
        if not isinstance(s, dict):
            continue  # tolerate stray non-object entries instead of crashing on .get
        serial = s.get("serial")
        lat = s.get("latitude")
        lon = s.get("longitude")
        label = s.get("label", serial)
        if serial is None or lat is None or lon is None:
            continue
        geo[str(serial)] = {"lat": float(lat), "lon": float(lon), "label": str(label)}
    if not geo:
        raise ValueError("No valid geo stations found.")
    return geo

# Module-level lookup str(serial) -> {"lat", "lon", "label"}, loaded once; read by geo_vec() and for target labels.
geo_map = load_geo_map(GEO_JSON_PATH)

# ============================================================
# Helpers: load + features + leakage-free split/impute
# ============================================================
def list_station_files(data_dir: str) -> List[str]:
    """Return the sorted paths of all ``*.csv`` files directly inside *data_dir*.

    Raises:
        FileNotFoundError: if no CSV file matches (including a missing directory).
    """
    pattern = os.path.join(data_dir, "*.csv")
    csv_paths = sorted(glob.glob(pattern))
    if csv_paths:
        return csv_paths
    raise FileNotFoundError(f"No CSVs in {data_dir}")

def station_name_from_path(p: str) -> str:
    """Station identifier for a CSV path: the file name without directory or final extension."""
    stem, _extension = os.path.splitext(os.path.basename(p))
    return stem

def load_station_csv(path: str) -> pd.DataFrame:
    """Read one station CSV, parse its ``datetime`` column and order rows chronologically.

    The returned frame carries a fresh 0..n-1 index.
    """
    frame = pd.read_csv(path)
    frame["datetime"] = pd.to_datetime(frame["datetime"])
    return frame.sort_values("datetime").reset_index(drop=True)

def add_calendar_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with cyclic sin/cos encodings of its ``datetime`` column.

    Adds hour-of-day (period 24), day-of-week (period 7, Monday=0) and
    day-of-year (period 365.25) encodings as ``<name>_sin``/``<name>_cos``
    columns, in that order. The input frame is left untouched.
    """
    out = df.copy()
    stamps = out["datetime"].dt
    cycles = (
        ("hour", stamps.hour.values, 24.0),
        ("dow", stamps.dayofweek.values, 7.0),
        ("doy", stamps.dayofyear.values, 365.25),
    )
    for prefix, values, period in cycles:
        angle = 2 * np.pi * values / period
        out[f"{prefix}_sin"] = np.sin(angle)
        out[f"{prefix}_cos"] = np.cos(angle)
    return out

def split_by_time_spans(df: pd.DataFrame, train_days: int, val_days: int):
    """Chronological train/val/test split anchored at the earliest timestamp t0.

    train = [t0, t0 + train_days), val = [t0 + train_days, t0 + train_days + val_days),
    test = everything from the end of val onward.

    Raises:
        ValueError: if train or val has fewer than SEQ_LEN + HORIZON + 24 rows,
            or test has fewer than SEQ_LEN + HORIZON rows (too short to window).

    Returns:
        (train, val, test) copies, each re-indexed from 0.
    """
    stamps = df["datetime"]
    train_end = stamps.min() + pd.Timedelta(days=train_days)
    val_end = train_end + pd.Timedelta(days=val_days)

    parts = (
        df[stamps < train_end].copy(),
        df[(stamps >= train_end) & (stamps < val_end)].copy(),
        df[stamps >= val_end].copy(),
    )

    window = SEQ_LEN + HORIZON
    # Train/val need one full window plus ~a day of extra rows; test needs a single window.
    minimum_rows = (("Train", window + 24), ("Val", window + 24), ("Test", window))
    for (label, needed), part in zip(minimum_rows, parts):
        if len(part) < needed:
            raise ValueError(f"{label} split too short: {len(part)} rows")

    return tuple(part.reset_index(drop=True) for part in parts)

def fill_missing_time_split_local(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
    """Impute NaNs in *cols* using only the rows of *df* itself (no cross-split leakage).

    Values are time-weighted interpolated on the ``datetime`` index; a final
    forward/backward fill guarantees edge NaNs are filled too. A column that is
    entirely NaN stays NaN. Returns a new frame with ``datetime`` restored as a column.
    """
    indexed = df.copy().set_index("datetime")
    interpolated = indexed[cols].interpolate(method="time", limit_direction="both")
    indexed[cols] = interpolated.ffill().bfill()
    return indexed.reset_index()

def reduce_train_last_months(train_df: pd.DataFrame, months: int = 6) -> pd.DataFrame:
    """Keep only the trailing *months* calendar months of a training frame.

    The window end is one hour past the last timestamp (i.e. hourly rows are
    assumed), and the start is that end minus ``DateOffset(months=months)``.
    Rows at or after the start are kept, re-indexed from 0.
    """
    cutoff = train_df["datetime"].max() + pd.Timedelta(hours=1) - pd.DateOffset(months=months)
    keep = train_df["datetime"] >= cutoff
    return train_df[keep].copy().reset_index(drop=True)

def dyn_cols_meteo_cal(meteo_cols: List[str]) -> List[str]:
    """Dynamic feature order: PM10, then the given meteo columns, then the calendar encodings."""
    return ["PM10", *meteo_cols, *CAL_COLS]

def prepare_station_record(path: str) -> Dict[str, object]:
    """Load one station CSV and return its split, imputed frames.

    Calendar features are added first. The frame is then split by time and
    each split is imputed independently, so no value crosses a split boundary.
    Rows are not re-indexed onto a regular time grid. NOTE(review): missing
    timestamps therefore stay missing and windows count rows, not hours — confirm
    the input CSVs are gap-free.

    Returns:
        {"station", "df_train", "df_val", "df_test", "meteo_cols"} where
        meteo_cols are the DEFAULT_METEO_NUM_COLS present in this CSV.
    """
    st = station_name_from_path(path)
    frame = add_calendar_features(load_station_csv(path))
    meteo_cols = [c for c in DEFAULT_METEO_NUM_COLS if c in frame.columns]

    # Leakage-safe ordering: split first, then impute within each split.
    splits = split_by_time_spans(frame, TRAIN_DAYS, VAL_DAYS)
    optional_pm25 = ["PM2.5"] if "PM2.5" in frame.columns else []
    cont_cols = ["PM10", *optional_pm25, *meteo_cols]
    tr, va, te = (fill_missing_time_split_local(part, cont_cols) for part in splits)

    return {"station": st, "df_train": tr, "df_val": va, "df_test": te, "meteo_cols": meteo_cols}

# ---- Discover and preprocess every station CSV (executes at import/cell run) ----
station_files = list_station_files(DATA_DIR)
stations = [station_name_from_path(p) for p in station_files]
# Integer id per station in sorted-file order; used as the nn.Embedding row index in HybridLocLSTM.
station_to_id = {s:i for i,s in enumerate(stations)}
num_stations = len(stations)

station_data = [prepare_station_record(p) for p in station_files]

# common meteo columns across all stations
common = sorted(list(set.intersection(*[set(d["meteo_cols"]) for d in station_data])))
X_COLS = dyn_cols_meteo_cal(common)

# Univariate feature set (PM10-only)
# NOTE(review): not referenced by the visible training/eval code, which passes X_COLS_MULT — confirm univariate runs are wired up elsewhere.
X_COLS_UNIV = ["PM10"]

# Multivariate feature set
X_COLS_MULT = X_COLS

# Method names dispatched by run_one_method().
METHODS = ["Scratch", "TL", "GeoAware_TL"]

print("Stations:", len(stations), "Common meteo:", common)

# Stations already reported as missing, so each warning is emitted only once.
_missing_geo_logged = set()

def geo_vec(station: str) -> np.ndarray:
    """Return the station's ``[lat, lon]`` as float32 (raw, unscaled values from geo_map).

    Unknown stations get ``[0, 0]`` plus a one-time warning. NOTE(review):
    (0, 0) is a real coordinate, so the hybrid model cannot distinguish
    "missing" from that location.
    """
    entry = geo_map.get(station)
    if entry is not None:
        return np.array([entry["lat"], entry["lon"]], dtype=np.float32)
    if station not in _missing_geo_logged:
        warnings.warn(f"Missing geo for station={station}; using zeros")
        _missing_geo_logged.add(station)
    return np.zeros((2,), dtype=np.float32)

# Precompute each station's float32 [lat, lon] once (zeros, with a warning, when absent from geo_map).
geo_vec_map = {s: geo_vec(s) for s in stations}


# ============================================================
# Scaling + windowing
# ============================================================
@dataclass
class Scalers:
    """Standardization statistics fitted on a training split only.

    x_mean / x_std are per-feature vectors aligned with the ``x_cols`` order
    used at fit time; y_mean / y_std are scalar statistics of the target.
    """
    x_mean: np.ndarray
    x_std: np.ndarray
    y_mean: float
    y_std: float

def fit_scalers(train_df: pd.DataFrame, x_cols: List[str], y_col: str = "PM10",
                min_std: float = 1e-6) -> Scalers:
    """Fit z-score statistics for *x_cols* and *y_col* on *train_df*.

    Args:
        train_df: training rows only (never val/test, to avoid leakage).
        x_cols: feature columns, in the order later used by apply_scalers.
        y_col: target column.
        min_std: standard deviations below this are treated as "constant" and
            replaced by 1.0. Without this, a column that is constant in training
            (e.g. zero rain over a short reduced span) would be divided by ~1e-8,
            turning any non-zero val/test value into ~1e8.

    Returns:
        Scalers with float32 x_mean / x_std and Python-float y_mean / y_std.
    """
    X = train_df[x_cols].values.astype(np.float32)
    y = train_df[y_col].values.astype(np.float32)
    x_sd = X.std(axis=0)
    y_sd = y.std()
    return Scalers(
        x_mean=X.mean(axis=0),
        x_std=np.where(x_sd < min_std, 1.0, x_sd + 1e-8).astype(np.float32),
        y_mean=float(y.mean()),
        y_std=1.0 if y_sd < min_std else float(y_sd + 1e-8),
    )

def apply_scalers(df: pd.DataFrame, sc: "Scalers", x_cols: List[str], y_col: str = "PM10") -> pd.DataFrame:
    """Return a standardized copy of *df*.

    Feature columns *x_cols* are replaced by ``(x - x_mean) / x_std``, and a new
    column ``f"{y_col}_y"`` holds the standardized target ``(y - y_mean) / y_std``.
    *df* itself is not modified.

    The target is read from the ORIGINAL frame before the features are scaled.
    ``y_col`` (PM10) is normally also one of ``x_cols``. Reading it after the
    feature scaling would standardize it twice, so ``inverse_y`` would return
    z-scores instead of target units and all reported metrics would be in the
    wrong units.
    """
    y_raw = df[y_col].values.astype(np.float32)
    d = df.copy()
    d[x_cols] = (d[x_cols].values.astype(np.float32) - sc.x_mean) / sc.x_std
    d[f"{y_col}_y"] = (y_raw - sc.y_mean) / sc.y_std
    return d

def make_windows(df: pd.DataFrame, x_cols: List[str], y_scaled_col: str, seq_len: int, horizon: int):
    """Slice a frame into supervised (input window, future target) pairs.

    For window i the input is rows ``[i, i + seq_len)`` of *x_cols*. The target
    is *y_scaled_col* at row ``i + seq_len + horizon - 1``, so horizon=1 means
    the row right after the window.

    Returns:
        X: float32 array of shape [n, seq_len, len(x_cols)]
        Y: float32 array of shape [n, 1]
        with ``n = max(0, len(df) - seq_len - horizon + 1)``. A frame too short
        for one window yields empty arrays of the right rank, so callers can
        check ``len(X) == 0`` (np.stack on an empty list would raise instead).

    Raises:
        ValueError: if seq_len < 1 or horizon < 1 (horizon 0 would put the
            target inside the input window).
    """
    if seq_len < 1 or horizon < 1:
        raise ValueError(f"seq_len and horizon must be >= 1 (got seq_len={seq_len}, horizon={horizon})")
    Xv = df[x_cols].values.astype(np.float32)
    yv = df[y_scaled_col].values.astype(np.float32)
    n = max(0, len(df) - seq_len - horizon + 1)
    # Row indices [n, seq_len]: window i covers rows i .. i + seq_len - 1.
    idx = np.arange(n)[:, None] + np.arange(seq_len)[None, :]
    X = Xv[idx]
    first_target = seq_len + horizon - 1
    Y = yv[first_target:first_target + n].reshape(-1, 1)
    return X, Y

def build_station_windows(station_record, x_cols: List[str], train_df_override: Optional[pd.DataFrame] = None):
    """Scale a station's splits and cut them into model-ready windows.

    Scalers are fitted on the training frame only (``train_df_override`` when
    given, e.g. a reduced train span; otherwise ``station_record["df_train"]``)
    and then applied to val and test.

    Returns:
        ((Xtr, ytr), (Xva, yva), (Xte, yte), tte, sc) where tte are the test
        timestamps aligned 1:1 with the test targets, and sc is the Scalers.
    """
    va = station_record["df_val"]
    te = station_record["df_test"]
    tr = station_record["df_train"] if train_df_override is None else train_df_override

    sc = fit_scalers(tr, x_cols, "PM10")
    scaled = [apply_scalers(frame, sc, x_cols, "PM10") for frame in (tr, va, te)]

    train_w, val_w, test_w = (
        make_windows(frame, x_cols, "PM10_y", SEQ_LEN, HORIZON) for frame in scaled
    )
    # First target row is SEQ_LEN + HORIZON - 1; later rows follow one per window.
    tte = scaled[2]["datetime"].iloc[SEQ_LEN + HORIZON - 1:].reset_index(drop=True)
    return train_w, val_w, test_w, tte, sc

def inverse_y(y_scaled: np.ndarray, sc: Scalers) -> np.ndarray:
    """Undo target standardization: ``y_mean + y_std * y_scaled``."""
    return sc.y_mean + sc.y_std * y_scaled

def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Regression metrics on flattened arrays.

    Returns MAE, MSE, RMSE, MAPE (%), sMAPE (%), R2 and explained variance
    (EVS), in that key order. Denominators are clipped at 1e-8, so MAPE
    becomes huge (not infinite) when y_true contains zeros.
    """
    actual = y_true.reshape(-1)
    predicted = y_pred.reshape(-1)
    resid = actual - predicted
    abs_resid = np.abs(resid)
    sq_resid = resid ** 2

    mse = float(np.mean(sq_resid))
    ss_tot = np.sum((actual - np.mean(actual)) ** 2)
    actual_var = float(np.var(actual))
    smape_den = np.clip(np.abs(actual) + np.abs(predicted), 1e-8, None)

    return {
        "MAE": float(np.mean(abs_resid)),
        "MSE": mse,
        "RMSE": float(np.sqrt(mse)),
        "MAPE": float(np.mean(abs_resid / np.clip(np.abs(actual), 1e-8, None)) * 100.0),
        "sMAPE": float(np.mean(2.0 * abs_resid / smape_den) * 100.0),
        "R2": float(1.0 - (np.sum(sq_resid) / np.clip(ss_tot, 1e-8, None))),
        "EVS": float(1.0 - (np.var(resid) / np.clip(actual_var, 1e-8, None))),
    }


# ============================================================
# Models
# ============================================================
class LSTMRegressor(nn.Module):
    """Two-layer batch-first LSTM with a linear head on the last time step.

    forward(x): x is [B, T, n_features]; returns [B, 1].
    """
    def __init__(self, n_features: int, hidden: int):
        super().__init__()
        # Attribute names/order are relied on by freeze_lstm_layer0 and by
        # state_dict transfer between source and target models.
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden, 1)

    def forward(self, x):
        sequence_out = self.lstm(x)[0]
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

class HybridLocLSTM(nn.Module):
    """Two-layer LSTM conditioned on station identity and location.

    A learned station embedding and an MLP projection of the geo vector are
    summed into one context vector. That vector is repeated along time and
    concatenated to every step of the dynamic input before the LSTM.

    forward shapes:
        dyn_seq: [B, T, F] float
        station_id: [B] long
        geo: [B, geo_dim] float
        returns: [B, 1]
    """
    def __init__(self, n_dyn_features: int, hidden: int, num_stations: int, embed_dim: int, geo_dim: int = 2):
        super().__init__()
        # Keep attribute names and creation order: state_dict keys and the
        # parameter order seen by the optimizer depend on them.
        self.station_emb = nn.Embedding(num_stations, embed_dim)
        self.geo_mlp = nn.Sequential(
            nn.Linear(geo_dim, 16),
            nn.ReLU(),
            nn.Linear(16, embed_dim),
            nn.ReLU(),
        )
        self.lstm = nn.LSTM(
            input_size=n_dyn_features + embed_dim,
            hidden_size=hidden,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden, 1)

    def forward(self, dyn_seq, station_id, geo):
        context = self.station_emb(station_id) + self.geo_mlp(geo)      # [B, E]
        steps = dyn_seq.size(1)
        context_seq = context[:, None, :].expand(-1, steps, -1)         # [B, T, E]
        lstm_in = torch.cat((dyn_seq, context_seq), dim=-1)             # [B, T, F+E]
        hidden_seq = self.lstm(lstm_in)[0]
        return self.fc(hidden_seq[:, -1, :])

def freeze_lstm_layer0(model: nn.Module):
    """Disable gradients for the first layer of ``model.lstm``.

    Matches every LSTM parameter whose name contains "l0" (weight_ih_l0,
    weight_hh_l0, bias_ih_l0, bias_hh_l0). Mutates in place and returns the
    same model for chaining.
    """
    layer0_params = (p for name, p in model.lstm.named_parameters() if "l0" in name)
    for param in layer0_params:
        param.requires_grad = False
    return model

def build_backbone_model(n_features: int) -> LSTMRegressor:
    """Create an LSTMRegressor with LSTM_UNITS hidden units, placed on the global device."""
    net = LSTMRegressor(n_features=n_features, hidden=LSTM_UNITS)
    return net.to(device)

def build_hybrid_model(n_dyn_features: int) -> HybridLocLSTM:
    """Create a HybridLocLSTM sized from the global config, placed on the global device."""
    config = dict(
        n_dyn_features=n_dyn_features,
        hidden=LSTM_UNITS,
        num_stations=num_stations,
        embed_dim=EMBED_DIM,
        geo_dim=GEO_DIM,
    )
    return HybridLocLSTM(**config).to(device)

def warm_start_station_embedding_from_source(model: HybridLocLSTM, source_station: str, target_station: str):
    """Overwrite the target station's embedding row with the source station's row.

    Done in place without autograd tracking; returns the same model.
    """
    with torch.no_grad():
        table = model.station_emb.weight
        src_row = station_to_id[source_station]
        dst_row = station_to_id[target_station]
        table[dst_row].copy_(table[src_row])
    return model


# ============================================================
# Training utilities (early stopping + reduce LR)
# ============================================================
def make_loader(X, y, batch_size, shuffle):
    """Wrap NumPy arrays X and y (sharing memory via torch.from_numpy) in a DataLoader of (X, y) batches."""
    dataset = TensorDataset(*(torch.from_numpy(arr) for arr in (X, y)))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)

def train_loop(model, train_loader, val_loader, optimizer, scheduler, max_epochs, patience, run_name: str,
               hybrid: bool = False, station_id=None, geo=None):
    """Train a plain ``model(X) -> y`` regressor with MSE, early stopping and best-weight restore.

    Each epoch makes one pass over ``train_loader`` (batches of ``(X, y)``),
    then computes the sample-weighted mean MSE over ``val_loader``. If given,
    ``scheduler`` is stepped with that validation loss (ReduceLROnPlateau
    style). Training stops after ``patience`` consecutive epochs without
    improvement. The best-validation weights are loaded back into ``model``
    before returning.

    Batches are moved to the device of the model's parameters, so data and
    weights always match.

    Args:
        run_name: currently unused; kept for call-site compatibility.
        hybrid: must be False. Hybrid models need per-sample station ids and
            geo vectors from the loader, so use ``train_loop_hybrid``.
        station_id, geo: unused; kept for backward compatibility.

    Returns:
        (elapsed wall-clock seconds, best validation MSE).

    Raises:
        RuntimeError: if ``hybrid`` is True. Raised up front, before any
            training work is done.
    """
    if hybrid:
        raise RuntimeError("Use train_loop_hybrid for hybrid models.")

    dev = next(model.parameters()).device
    loss_fn = nn.MSELoss()
    best_val = float("inf")
    bad = 0
    best_state = None
    t0 = time.time()

    for _epoch in range(1, max_epochs + 1):
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(dev)
            yb = yb.to(dev)
            optimizer.zero_grad(set_to_none=True)
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()

        # Validation: weight each batch by its size so a short last batch is not over-weighted.
        model.eval()
        vloss = 0.0
        n = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(dev)
                yb = yb.to(dev)
                vloss += loss_fn(model(xb), yb).item() * len(xb)
                n += len(xb)
        vloss /= max(1, n)

        if scheduler is not None:
            scheduler.step(vloss)

        if vloss + 1e-12 < best_val:
            best_val = vloss
            bad = 0
            # CPU snapshot so later epochs cannot overwrite the best weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict({k: v.to(dev) for k, v in best_state.items()})

    return time.time() - t0, best_val

def make_hybrid_loader(X, y, station_ids, geos, batch_size, shuffle):
    """DataLoader yielding ``(X, y, station_id, geo)`` batches for HybridLocLSTM.

    station_ids are cast to int64 (embedding indices) and geos to float32.
    """
    columns = (
        torch.from_numpy(X),
        torch.from_numpy(y),
        torch.from_numpy(station_ids.astype(np.int64)),
        torch.from_numpy(geos.astype(np.float32)),
    )
    dataset = TensorDataset(*columns)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False)

def train_loop_hybrid(model, train_loader, val_loader, optimizer, scheduler, max_epochs, patience, run_name: str):
    """Train a HybridLocLSTM with MSE, early stopping and best-weight restore.

    Loaders must yield ``(X, y, station_id, geo)`` batches (see
    make_hybrid_loader). Validation loss is the sample-weighted mean MSE and is
    passed to ``scheduler.step`` when a scheduler is given. Training stops
    after ``patience`` epochs without improvement, and the best weights are
    restored. ``run_name`` is unused.

    Returns:
        (elapsed wall-clock seconds, best validation MSE).
    """
    best_val = float("inf")
    stale_epochs = 0
    best_snapshot = None
    started = time.time()
    criterion = nn.MSELoss()

    def _on_device(batch):
        return tuple(t.to(device) for t in batch)

    for _epoch in range(max_epochs):
        model.train()
        for batch in train_loader:
            xb, yb, sidb, geob = _on_device(batch)
            optimizer.zero_grad(set_to_none=True)
            loss = criterion(model(xb, sidb, geob), yb)
            loss.backward()
            optimizer.step()

        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                xb, yb, sidb, geob = _on_device(batch)
                batch_loss = criterion(model(xb, sidb, geob), yb).item()
                total += batch_loss * len(xb)
                count += len(xb)
        val_loss = total / max(1, count)

        if scheduler is not None:
            scheduler.step(val_loss)

        if val_loss + 1e-12 < best_val:
            best_val = val_loss
            stale_epochs = 0
            # CPU copy so subsequent epochs cannot mutate the saved best weights.
            best_snapshot = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            continue

        stale_epochs += 1
        if stale_epochs >= patience:
            break

    if best_snapshot is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_snapshot.items()})

    return time.time() - started, best_val

@torch.no_grad()
def predict(model, loader, hybrid: bool = False):
    """Run *model* in eval mode over *loader* and stack the outputs into one NumPy array.

    Plain loaders yield ``(X, y)`` and hybrid loaders yield
    ``(X, y, station_id, geo)``; targets are ignored. An empty loader raises,
    because np.vstack needs at least one array.
    """
    model.eval()
    outputs = []
    for batch in loader:
        if hybrid:
            xb, _yb, sidb, geob = batch
            out = model(xb.to(device), sidb.to(device), geob.to(device))
        else:
            xb, _yb = batch
            out = model(xb.to(device))
        outputs.append(out.detach().cpu().numpy())
    return np.vstack(outputs)

def log_torch_model(model: nn.Module, artifact_path: str, fname: str = "model.pt"):
    """Save ``model.state_dict()`` under *artifact_path*/*fname* and log it to the active MLflow run.

    The MLflow artifact directory is the basename of *artifact_path*.
    """
    os.makedirs(artifact_path, exist_ok=True)
    local_file = os.path.join(artifact_path, fname)
    torch.save(model.state_dict(), local_file)
    remote_dir = os.path.basename(artifact_path)
    mlflow.log_artifact(local_file, artifact_path=remote_dir)


# ============================================================
# Core: source training + target eval
# ============================================================
def train_source_models(source_record):
    """Train the two source-station models that the TL methods start from.

    Both are trained on the source station's full training split, using the
    multivariate feature set (X_COLS_MULT) and early stopping on its val split:
      * a plain LSTMRegressor backbone (used by "TL"),
      * a HybridLocLSTM fed the source station's id and [lat, lon] (used by "GeoAware_TL").

    NOTE(review): a target run must use the same feature width for
    load_state_dict to succeed — univariate TL would need its own backbone.

    Returns:
        dict with "src_station", "backbone_model", "backbone_train_time",
        "hybrid_model" and "hybrid_train_time" (seconds).
    """
    src = source_record["station"]
    pack = {"src_station": src}

    # Only train/val windows are needed here; test windows, timestamps and scalers are discarded.
    (Xtr, ytr), (Xva, yva), _, _, _ = build_station_windows(source_record, X_COLS_MULT)

    m_backbone = build_backbone_model(n_features=Xtr.shape[-1])
    opt = torch.optim.Adam(m_backbone.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=REDUCE_LR_PATIENCE, min_lr=1e-6)
    tr_loader = make_loader(Xtr, ytr, BATCH_SIZE, shuffle=True)
    va_loader = make_loader(Xva, yva, BATCH_SIZE, shuffle=False)
    dt_backbone, _ = train_loop(
        m_backbone, tr_loader, va_loader, opt, sch,
        MAX_EPOCHS, EARLY_STOP_PATIENCE, run_name=f"src_backbone_{src}"
    )

    # Hybrid inputs: every sample carries the source station's id and its [lat, lon] vector.
    sid_tr = np.full((len(Xtr),), station_to_id[src], dtype=np.int64)
    sid_va = np.full((len(Xva),), station_to_id[src], dtype=np.int64)
    geo = geo_vec_map[src].astype(np.float32)
    geo_tr = np.tile(geo.reshape(1, -1), (len(Xtr), 1))
    geo_va = np.tile(geo.reshape(1, -1), (len(Xva), 1))

    m_hybrid = build_hybrid_model(n_dyn_features=Xtr.shape[-1])
    opt2 = torch.optim.Adam(m_hybrid.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sch2 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt2, factor=0.5, patience=REDUCE_LR_PATIENCE, min_lr=1e-6)
    tr_loader2 = make_hybrid_loader(Xtr, ytr, sid_tr, geo_tr, BATCH_SIZE, shuffle=True)
    va_loader2 = make_hybrid_loader(Xva, yva, sid_va, geo_va, BATCH_SIZE, shuffle=False)
    dt_hybrid, _ = train_loop_hybrid(
        m_hybrid, tr_loader2, va_loader2, opt2, sch2,
        MAX_EPOCHS, EARLY_STOP_PATIENCE, run_name=f"src_hybrid_{src}"
    )

    pack.update({
        "backbone_model": m_backbone,
        "backbone_train_time": dt_backbone,
        "hybrid_model": m_hybrid,
        "hybrid_train_time": dt_hybrid,
    })
    return pack

def log_common_params(source_station: str, target_station: str, method: str,
                      x_cols: List[str], train_windows_full: int, train_windows_used: int,
                      train_months_used: int, target_col: str, target_name: str):
    """Log run identifiers, data-protocol settings and hyperparameters to the active MLflow run.

    Every training setting that influences results is logged, including
    weight decay, early-stopping / LR-plateau patience and the global seed.
    Runs that differ only in those settings therefore remain distinguishable.
    ``lat``/``lon`` come from ``geo_vec_map`` and are 0.0 for stations with no
    geo entry.
    """
    mlflow.log_params({
        "source_station": source_station,
        "target_station": target_station,
        "target_name": target_name,
        "target_col": target_col,
        "method": method,
        "feature_set": "univariate" if x_cols == ["PM10"] else "multivariate",
        "features": ",".join(x_cols),

        "train_days": TRAIN_DAYS,
        "val_days": VAL_DAYS,
        "train_months_used": int(train_months_used),
        "reduction_mode": "last_months",

        "seq_len": SEQ_LEN,
        "horizon": HORIZON,

        "lstm_units": LSTM_UNITS,
        "embed_dim": EMBED_DIM,
        "geo_dim": GEO_DIM,
        "lr": LR,
        "weight_decay": WEIGHT_DECAY,
        "batch_size": BATCH_SIZE,
        "max_epochs": MAX_EPOCHS,
        "early_stop_patience": EARLY_STOP_PATIENCE,
        "reduce_lr_patience": REDUCE_LR_PATIENCE,
        "seed": SEED,

        "train_windows_full": train_windows_full,
        "train_windows_used": train_windows_used,

        "device": str(device),

        "lat": float(geo_vec_map[target_station][0]),
        "lon": float(geo_vec_map[target_station][1]),
    })

def log_predictions_artifact(
    pred_df: pd.DataFrame,
    source_station: str,
    target_station: str,
    method: str,
    train_months_used: int,
    artifact_root: str = "predictions",
):
    """Write *pred_df* to a local CSV and log it to MLflow under ``<root>/k=<k>/<method>``.

    "/" in the method name is replaced with "_" for both the file name and the
    artifact path. Returns the local CSV path.
    """
    safe_method = method.replace("/", "_")
    out_dir = os.path.join(RESULTS_DIR, "tmp_predictions")
    os.makedirs(out_dir, exist_ok=True)
    csv_name = f"pred_{source_station}_to_{target_station}_{safe_method}_k{train_months_used}.csv"
    csv_path = os.path.join(out_dir, csv_name)
    pred_df.to_csv(csv_path, index=False)
    mlflow.log_artifact(csv_path, artifact_path=f"{artifact_root}/k={train_months_used}/{safe_method}")
    return csv_path

def run_one_method(
    method: str,
    source_station: str,
    src_pack: dict,
    target_record: dict,
    target_name: str,
    target_col: str,
    train_months_used: int,
    x_cols: List[str],
    train_windows_full: int,
    train_windows_used: int,
    window_pack: dict,
):
    """Train one method on a target station in a nested MLflow run and evaluate it on test.

    Methods:
        "Scratch":     fresh LSTMRegressor trained on the target windows only.
        "TL":          copy of the source backbone; LSTM layer 0 frozen, the rest fine-tuned.
        "GeoAware_TL": copy of the source HybridLocLSTM. When target != source,
                       the target's station-embedding row is first copied from the
                       source row. LSTM layer 0 is frozen and the rest is fine-tuned
                       with the target's id and geo vector.

    Args:
        src_pack: output of train_source_models for ``source_station``.
        window_pack: pre-built arrays "Xtr","ytr","Xva","yva","Xte","yte" (scaled),
            "tte" (test timestamps), "sc" (Scalers) and "yt" (inverse-scaled test targets).

    The run logs params, test metrics, the fitted state_dict and a predictions CSV.

    Returns:
        One result row: identifiers, all compute_metrics keys and TrainTimeSec.

    Raises:
        ValueError: for an unknown ``method``.
    """
    tgt = target_record["station"]
    Xtr, ytr = window_pack["Xtr"], window_pack["ytr"]
    Xva, yva = window_pack["Xva"], window_pack["yva"]
    Xte, yte = window_pack["Xte"], window_pack["yte"]
    tte, sc, yt = window_pack["tte"], window_pack["sc"], window_pack["yt"]

    run_name = f"method={method}_k={train_months_used}"
    with mlflow.start_run(run_name=run_name, nested=True):
        log_common_params(
            source_station=source_station,
            target_station=tgt,
            method=method,
            x_cols=x_cols,
            train_windows_full=train_windows_full,
            train_windows_used=train_windows_used,
            train_months_used=train_months_used,
            target_col=target_col,
            target_name=target_name,
        )

        if method == "Scratch":
            model = build_backbone_model(n_features=Xtr.shape[-1])
            opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=REDUCE_LR_PATIENCE, min_lr=1e-6)
            tr_loader = make_loader(Xtr, ytr, BATCH_SIZE, shuffle=True)
            va_loader = make_loader(Xva, yva, BATCH_SIZE, shuffle=False)
            te_loader = make_loader(Xte, yte, BATCH_SIZE, shuffle=False)
            dt, _ = train_loop(
                model, tr_loader, va_loader, opt, sch,
                MAX_EPOCHS, EARLY_STOP_PATIENCE, run_name=f"scratch_{tgt}_k{train_months_used}"
            )
            yp = inverse_y(predict(model, te_loader, hybrid=False), sc)

        elif method == "TL":
            # Fresh instance so fine-tuning never mutates the shared source model.
            model = build_backbone_model(n_features=Xtr.shape[-1])
            model.load_state_dict(src_pack["backbone_model"].state_dict())
            freeze_lstm_layer0(model)
            # Optimizer only sees still-trainable params (built after freezing).
            opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=REDUCE_LR_PATIENCE, min_lr=1e-6)
            tr_loader = make_loader(Xtr, ytr, BATCH_SIZE, shuffle=True)
            va_loader = make_loader(Xva, yva, BATCH_SIZE, shuffle=False)
            te_loader = make_loader(Xte, yte, BATCH_SIZE, shuffle=False)
            dt, _ = train_loop(
                model, tr_loader, va_loader, opt, sch,
                MAX_EPOCHS, EARLY_STOP_PATIENCE, run_name=f"tl_{tgt}_k{train_months_used}"
            )
            yp = inverse_y(predict(model, te_loader, hybrid=False), sc)

        elif method == "GeoAware_TL":
            model = build_hybrid_model(n_dyn_features=Xtr.shape[-1])
            model.load_state_dict(src_pack["hybrid_model"].state_dict())
            warm_started = tgt != source_station
            if warm_started:
                warm_start_station_embedding_from_source(model, source_station, tgt)
            freeze_lstm_layer0(model)
            # Log what actually happened: no embedding copy is made when target == source.
            mlflow.log_param("warm_start_embedding", warm_started)

            sid_tr = np.full((len(Xtr),), station_to_id[tgt], dtype=np.int64)
            sid_va = np.full((len(Xva),), station_to_id[tgt], dtype=np.int64)
            sid_te = np.full((len(Xte),), station_to_id[tgt], dtype=np.int64)
            geo = geo_vec_map[tgt].astype(np.float32)
            geo_tr = np.tile(geo.reshape(1, -1), (len(Xtr), 1))
            geo_va = np.tile(geo.reshape(1, -1), (len(Xva), 1))
            geo_te = np.tile(geo.reshape(1, -1), (len(Xte), 1))

            tr_loader = make_hybrid_loader(Xtr, ytr, sid_tr, geo_tr, BATCH_SIZE, shuffle=True)
            va_loader = make_hybrid_loader(Xva, yva, sid_va, geo_va, BATCH_SIZE, shuffle=False)
            te_loader = make_hybrid_loader(Xte, yte, sid_te, geo_te, BATCH_SIZE, shuffle=False)
            opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=REDUCE_LR_PATIENCE, min_lr=1e-6)
            dt, _ = train_loop_hybrid(
                model, tr_loader, va_loader, opt, sch,
                MAX_EPOCHS, EARLY_STOP_PATIENCE, run_name=f"geoaware_tl_{tgt}_k{train_months_used}"
            )
            yp = inverse_y(predict(model, te_loader, hybrid=True), sc)
        else:
            raise ValueError(f"Unknown method: {method}")

        m = compute_metrics(yt, yp)
        mlflow.log_metrics({
            "mae": m["MAE"],
            "rmse": m["RMSE"],
            "mse": m["MSE"],
            "mape": m["MAPE"],
            "smape": m["sMAPE"],
            "r2": m["R2"],
            "evs": m["EVS"],
            "train_time_sec": float(dt),
        })

        tmp_dir = os.path.join(RESULTS_DIR, "tmp_models")
        os.makedirs(tmp_dir, exist_ok=True)
        model_path = os.path.join(tmp_dir, f"{source_station}_to_{tgt}_{method}_k{train_months_used}.pt")
        torch.save(model.state_dict(), model_path)
        mlflow.log_artifact(model_path, artifact_path=f"model/k={train_months_used}/{method}")

        pred_df = pd.DataFrame({
            "datetime": pd.to_datetime(tte).astype(str),
            "y_true": yt.reshape(-1),
            "y_pred": yp.reshape(-1),
            "source_station": source_station,
            "target_station": tgt,
            "method": method,
            "train_months_used": int(train_months_used),
            "target_col": target_col,
            "target_name": target_name,
        })
        log_predictions_artifact(
            pred_df=pred_df,
            source_station=source_station,
            target_station=tgt,
            method=method,
            train_months_used=train_months_used,
        )

        return {
            "TargetName": target_name,
            "TargetCol": target_col,
            "SourceStation": source_station,
            "TargetStation": tgt,
            "Method": method,
            "TrainMonthsUsed": int(train_months_used),
            **m,
            "TrainTimeSec": float(dt),
        }

def _sweep_window_count(n_rows: int) -> int:
    """Return how many contiguous spans of SEQ_LEN + HORIZON rows fit in ``n_rows`` rows (never negative)."""
    return max(0, n_rows - SEQ_LEN - HORIZON + 1)


def _sweep_window_pack(target_record: dict, train_reduced, tgt: str, k) -> Optional[dict]:
    """Window one target station using a reduced training frame.

    Args:
        target_record: station record passed through to ``build_station_windows``.
        train_reduced: training frame used via ``train_df_override``.
        tgt: target station id (only used in warning messages).
        k: months value of the sweep (only used in warning messages).

    Returns:
        The ``window_pack`` dict consumed by ``run_one_method``, with keys
        Xtr/ytr/Xva/yva/Xte/yte/tte/sc plus ``yt`` (``yte`` mapped back through
        ``inverse_y``). Returns ``None`` after a warning when windowing raises
        or any of the train/val/test splits has zero windows.
    """
    try:
        (Xtr, ytr), (Xva, yva), (Xte, yte), tte, sc = build_station_windows(
            target_record, X_COLS_MULT, train_df_override=train_reduced
        )
    except Exception as e:  # best-effort: one bad k must not abort the whole sweep
        warnings.warn(f"Skipping target={tgt} k={k}: failed windowing ({e})")
        return None

    if len(Xtr) == 0 or len(Xva) == 0 or len(Xte) == 0:
        warnings.warn(f"Skipping target={tgt} k={k}: empty windows (tr={len(Xtr)}, va={len(Xva)}, te={len(Xte)})")
        return None

    return {
        "Xtr": Xtr, "ytr": ytr,
        "Xva": Xva, "yva": yva,
        "Xte": Xte, "yte": yte,
        "tte": tte, "sc": sc, "yt": inverse_y(yte, sc),
    }


def eval_target_sweep(source_station: str, src_pack: dict, target_record: dict):
    """Evaluate every method in METHODS on one target station across TRAIN_MONTHS_SWEEP.

    For each ``k`` in the sweep, the target's training frame is reduced via
    ``reduce_train_last_months(train_full, months=k)``, the station is
    re-windowed, and each method is run through ``run_one_method``. MLflow runs
    are nested as ``target=<tgt>`` -> ``k=<k>``.

    Args:
        source_station: id of the station whose source models are transferred.
        src_pack: output of ``train_source_models`` for ``source_station``.
        target_record: station record; must provide ``"station"`` and ``"df_train"``.

    Returns:
        List of result dicts (one per non-skipped k and method) as returned by
        ``run_one_method``. Months whose windowing fails or yields an empty
        split are skipped; they are tagged on the MLflow runs because the
        notebook silences Python warnings globally.

    Raises:
        AssertionError: if a non-skipped k lacks a result for some method.
    """
    tgt = target_record["station"]
    # NOTE(review): target column is fixed here; the actual y column is chosen
    # inside build_station_windows — confirm both agree.
    target_col = "PM10"
    target_name = geo_map.get(tgt, {}).get("label", tgt)

    train_full = target_record["df_train"]
    train_windows_full = _sweep_window_count(len(train_full))
    rows = []
    skipped_months = set()

    with mlflow.start_run(run_name=f"target={tgt}", nested=True):
        for k in TRAIN_MONTHS_SWEEP:
            months = int(k)
            train_reduced = reduce_train_last_months(train_full, months=months)
            train_windows_used = _sweep_window_count(len(train_reduced))

            with mlflow.start_run(run_name=f"k={k}", nested=True):
                pack = _sweep_window_pack(target_record, train_reduced, tgt, k)
                if pack is None:
                    skipped_months.add(months)
                    # warnings.warn output is suppressed by filterwarnings("ignore"),
                    # so also mark the skip on the (otherwise empty) k run.
                    mlflow.set_tag("skipped", "true")
                    continue

                for method in METHODS:
                    rows.append(
                        run_one_method(
                            method=method,
                            source_station=source_station,
                            src_pack=src_pack,
                            target_record=target_record,
                            target_name=target_name,
                            target_col=target_col,
                            train_months_used=months,
                            x_cols=X_COLS_MULT,
                            train_windows_full=train_windows_full,
                            train_windows_used=train_windows_used,
                            window_pack=pack,
                        )
                    )

        if skipped_months:
            # Summarise skipped months on the per-target run.
            mlflow.set_tag("skipped_months", ",".join(str(m) for m in sorted(skipped_months)))

    # Sanity check: every non-skipped k must have produced a row for every method.
    methods_by_k = {}
    for r in rows:
        methods_by_k.setdefault(int(r["TrainMonthsUsed"]), set()).add(r["Method"])
    expected = set(METHODS)
    for k in TRAIN_MONTHS_SWEEP:
        months = int(k)
        if months in skipped_months:
            continue
        got = methods_by_k.get(months, set())
        if got != expected:
            raise AssertionError(f"Incomplete methods for target={tgt}, k={months}: got={sorted(got)} expected={METHODS}")

    return rows

# ============================================================
# Run protocol (MLflow nesting + tqdm)
# ============================================================
# Run hierarchy: top run -> source=<station> -> target=<station> -> k=<months>.


def _log_source_model_weights(src_station: str, src_pack: dict) -> None:
    """Save the source backbone and hybrid state dicts and attach them to the active MLflow run."""
    tmp_dir = os.path.join(RESULTS_DIR, "tmp_models")
    os.makedirs(tmp_dir, exist_ok=True)
    checkpoints = [
        (f"{src_station}_source_backbone.pt", src_pack["backbone_model"]),
        (f"{src_station}_source_hybrid.pt", src_pack["hybrid_model"]),
    ]
    for fname, net in checkpoints:
        ckpt_path = os.path.join(tmp_dir, fname)
        torch.save(net.state_dict(), ckpt_path)
        mlflow.log_artifact(ckpt_path, artifact_path="source_models")


all_rows = []

with mlflow.start_run(run_name="cross_station_train_months_sweep_torch_leakage_fixed"):
    # Compact hyper-parameters, shown as MLflow params.
    mlflow.log_params(dict(
        data_dir=DATA_DIR,
        geo_json=GEO_JSON_PATH,
        train_days=TRAIN_DAYS,
        val_days=VAL_DAYS,
        train_months_sweep=",".join(str(k) for k in TRAIN_MONTHS_SWEEP),
        seq_len=SEQ_LEN,
        horizon=HORIZON,
        lstm_units=LSTM_UNITS,
        embed_dim=EMBED_DIM,
        lr=LR,
        batch_size=BATCH_SIZE,
        max_epochs=MAX_EPOCHS,
        device=str(device),
        n_stations=len(stations),
        leakage_fix="split_then_impute_per_split",
    ))

    # Fuller configuration, written to JSON and attached as an artifact.
    config_snapshot = dict(
        seed=SEED,
        train_days=TRAIN_DAYS,
        val_days=VAL_DAYS,
        train_months_sweep=TRAIN_MONTHS_SWEEP,
        seq_len=SEQ_LEN,
        horizon=HORIZON,
        batch_size=BATCH_SIZE,
        max_epochs=MAX_EPOCHS,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        early_stop_patience=EARLY_STOP_PATIENCE,
        reduce_lr_patience=REDUCE_LR_PATIENCE,
        lstm_units=LSTM_UNITS,
        embed_dim=EMBED_DIM,
        geo_dim=GEO_DIM,
        x_cols_mult=X_COLS_MULT,
        methods=METHODS,
        device=str(device),
    )
    snap_path = os.path.join(RESULTS_DIR, "config_snapshot.json")
    with open(snap_path, "w", encoding="utf-8") as fh:
        json.dump(config_snapshot, fh, indent=2)
    mlflow.log_artifact(snap_path, artifact_path="config")

    for source_record in tqdm(station_data, desc="Source station folds"):
        src_station = source_record["station"]

        with mlflow.start_run(run_name=f"source={src_station}", nested=True):
            src_pack = train_source_models(source_record)
            mlflow.log_metrics(dict(
                src_backbone_train_time_sec=float(src_pack["backbone_train_time"]),
                src_hybrid_train_time_sec=float(src_pack["hybrid_train_time"]),
            ))
            _log_source_model_weights(src_station, src_pack)

            # Every station other than the current source is evaluated as a target.
            targets = [rec for rec in station_data if rec["station"] != src_station]
            for target_record in tqdm(targets, desc=f"Targets for {src_station}", leave=False):
                all_rows.extend(eval_target_sweep(src_station, src_pack, target_record))

# Save results
# Long format: one row per (source, target, method, train-months) evaluation.
results_df = pd.DataFrame(all_rows)
required_cols = [
    "TargetName", "TargetCol", "SourceStation", "TargetStation", "Method", "TrainMonthsUsed",
    "MAE", "RMSE", "MSE", "MAPE", "sMAPE", "R2", "EVS", "TrainTimeSec"
]
# Pin a fixed schema and column order: any absent column (e.g. when no rows
# were produced) is added as all-NaN.
absent_cols = {col: np.nan for col in required_cols if col not in results_df.columns}
results_df = results_df.assign(**absent_cols)[required_cols]
metrics_path = os.path.join(RESULTS_DIR, "metrics_long.csv")
results_df.to_csv(metrics_path, index=False)
print("Saved:", metrics_path)

# Aggregate artifacts
# Mean/std of each metric per (Method, TrainMonthsUsed) across all source/target pairs.
# NOTE: .agg([...]) yields (metric, stat) MultiIndex columns, so this CSV has two header rows.
summary = results_df.groupby(["Method", "TrainMonthsUsed"])[
    ["MAE", "RMSE", "MSE", "MAPE", "sMAPE", "R2", "EVS", "TrainTimeSec"]
].agg(["mean", "std"]).reset_index()
summary_path = os.path.join(RESULTS_DIR, "summary_by_method.csv")
summary.to_csv(summary_path, index=False)

# One MAE column per method for every (source, target, months) cell, plus
# GeoAware_TL-minus-baseline deltas (negative = GeoAware_TL has lower MAE).
# NOTE(review): the notebook header lists methods such as Scratch_Univar/TL_Univar;
# if METHODS uses those names, the delta columns below are never produced — confirm.
pivot = pd.pivot_table(
    results_df,
    index=["SourceStation", "TargetStation", "TrainMonthsUsed"],
    columns="Method",
    values="MAE",
).reset_index()
for baseline, delta_col in (("TL", "GeoAware_minus_TL_MAE"), ("Scratch", "GeoAware_minus_Scratch_MAE")):
    if "GeoAware_TL" in pivot.columns and baseline in pivot.columns:
        pivot[delta_col] = pivot["GeoAware_TL"] - pivot[baseline]
delta_path = os.path.join(RESULTS_DIR, "deltas.csv")
pivot.to_csv(delta_path, index=False)

# Publish the result tables and the per-method mean MAE in their own top-level MLflow run.
with mlflow.start_run(run_name="analysis_summary_torch"):
    for table_path in (metrics_path, summary_path, delta_path):
        mlflow.log_artifact(table_path, artifact_path="results")
    mean_mae_by_method = results_df.groupby("Method")["MAE"].mean()
    mean_mae = mean_mae_by_method.to_dict()
    mlflow.log_metrics({f"mean_mae_{name}": float(val) for name, val in mean_mae.items()})

results_df.head()
