In [None]:

from __future__ import annotations
import math
import os
from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    from torch_geometric.data import HeteroData
    from torch_geometric.nn import HeteroConv, GATv2Conv
except Exception as e:
    raise SystemExit("Please install torch-geometric and matching torch packages.")


# Utility: sMAPE


def smape(y_true, y_pred, eps: float = 1e-8):
    """Return the symmetric mean absolute percentage error as a Python float.

    Computes ``mean(2 * |y_pred - y_true| / max(|y_true| + |y_pred|, eps))``,
    which lies in ``[0, 2]``.

    The numerator is taken on the signed difference. Taking ``abs`` of each
    operand first would score a prediction of ``-x`` against a target of ``x``
    as a perfect match.

    Args:
        y_true: Targets, either a ``torch.Tensor`` or anything ``numpy`` can
            turn into an array.
        y_pred: Predictions, same shape as ``y_true``.
        eps: Lower bound for the denominator, so that a pair of zeros yields 0
            instead of a division by zero.
    """
    if isinstance(y_true, torch.Tensor):
        # as_tensor is a no-op for tensors and lets a numpy prediction through.
        y_pred = torch.as_tensor(y_pred)
        denom = torch.clamp(y_true.abs() + y_pred.abs(), min=eps)
        return torch.mean(2.0 * torch.abs(y_pred - y_true) / denom).item()

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    denom = np.clip(np.abs(y_true) + np.abs(y_pred), eps, None)
    return float(np.mean(2.0 * np.abs(y_pred - y_true) / denom))


# RCA builders


def _rca_binary(flows: pd.DataFrame, role: str) -> pd.DataFrame:
    """Return binary (RCA>1) matrix: countries x products.
    role ∈ {"export", "import"}
    """
    if role == "export":
        M = flows.pivot_table(index="exporter_id", columns="product_id", values="amount_usd", aggfunc="sum", fill_value=0.0)
    elif role == "import":
        M = flows.pivot_table(index="importer_id", columns="product_id", values="amount_usd", aggfunc="sum", fill_value=0.0)
    else:
        raise ValueError("role must be 'export' or 'import'")

    country_tot = M.sum(axis=1)
    prod_tot = M.sum(axis=0)
    grand = float(prod_tot.sum()) if float(prod_tot.sum()) > 0 else np.nan

    share_country = M.div(country_tot.replace(0, np.nan), axis=0)
    share_global = prod_tot / grand
    rca = share_country.div(share_global.replace(0, np.nan), axis=1)
    rca_bin = (rca > 1.0).astype(np.float32).fillna(0.0)
    return rca_bin


# Graph builder (heterogeneous Country/Product graph)


def build_hetero_graph_from_rca(flows_2024: pd.DataFrame) -> Tuple[HeteroData, Dict[str, int], Dict[str, int], List[str]]:
    """Assemble the country/product ``HeteroData`` graph from binary RCA matrices.

    Node types are ``'country'`` and ``'product'``. For each of the export and
    import RCA matrices, every cell equal to 1 becomes an edge
    ``('country', '<role>s_rca', 'product')`` plus the mirrored
    ``('product', '<role>s_rca_rev', 'country')`` edge.

    Country features are the export-RCA row followed by the import-RCA row,
    both over the same sorted product list (width ``2 * n_products``). Product
    nodes carry no features; only ``num_nodes`` is set.

    Args:
        flows_2024: Trade flows used to compute RCA. Despite the name, the
            function applies no year filter itself; it uses whatever it is given.

    Returns:
        ``(data, country2idx, product2idx, product_list)`` where the two dicts
        map ids to node indices and ``product_list`` is the sorted product ids.
    """
    export_rca = _rca_binary(flows_2024, role="export")
    import_rca = _rca_binary(flows_2024, role="import")

    # Put both matrices on one shared, sorted (country, product) grid.
    all_products = sorted(set(export_rca.columns) | set(import_rca.columns))
    all_countries = sorted(set(export_rca.index) | set(import_rca.index))
    export_rca = export_rca.reindex(index=all_countries, columns=all_products, fill_value=0.0)
    import_rca = import_rca.reindex(index=all_countries, columns=all_products, fill_value=0.0)

    country2idx = dict(zip(all_countries, range(len(all_countries))))
    product2idx = dict(zip(all_products, range(len(all_products))))

    data = HeteroData()
    country_features = np.hstack([export_rca.values, import_rca.values]).astype(np.float32)
    data["country"].x = torch.tensor(country_features, dtype=torch.float32)  # [Nc, 2P]
    data["product"].num_nodes = len(all_products)

    # One forward and one reverse relation per role, so messages flow both ways.
    for relation, matrix in (("exports_rca", export_rca), ("imports_rca", import_rca)):
        country_ids, product_ids = np.nonzero(matrix.values >= 1)
        data["country", relation, "product"].edge_index = torch.tensor(
            np.vstack([country_ids, product_ids]), dtype=torch.long
        )
        data["product", f"{relation}_rev", "country"].edge_index = torch.tensor(
            np.vstack([product_ids, country_ids]), dtype=torch.long
        )

    return data, country2idx, product2idx, all_products


# Dataset building: lags & month features


def _prepare_monthly(flows: pd.DataFrame) -> pd.DataFrame:
    df = flows.copy()
    # make sure'date' is datetime — but don't destroy it if it already is
    if not np.issubdtype(df["date"].dtype, np.datetime64):
        df["date"] = pd.to_datetime(df["date"].astype(str), format="%Y%m", errors="coerce")

    df["ym"] = df["date"].dt.to_period("M").dt.to_timestamp()
    # Aggregate to monthly country–country–product flows
    df = (
        df.groupby(["ym", "exporter_id", "importer_id", "product_id"], as_index=False)
          .agg(amount_usd=("amount_usd", "sum"))
    )
    return df


def build_samples_for_holdout(
    flows: pd.DataFrame,
    holdout_month: str,
    W_months: int,
    lag_K: int = 3,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Return ``(train_rows, val_rows)`` for a holdout month ``hm``.

    * Train rows: feature months ``t`` in ``[T0, Tend]`` with ``Tend = hm-2``
      and ``T0 = Tend - (W_months - 1)``; their targets reach up to ``hm-1``.
    * Val rows: feature month ``t = hm-1`` with target month ``hm``.

    Each row describes one (exporter, importer, product) triple observed in
    month ``t`` and carries: ``exporter_id``, ``importer_id``, ``product_id``,
    ``ym`` (feature month), ``amount_usd``, ``lag1..lagK``, ``next_amount``,
    ``m_sin``, ``m_cos``, ``y_amt`` and ``y_pos``.

    Lags and targets are aligned by calendar month: ``lagk`` is the triple's
    amount exactly ``k`` months before ``t`` and the target is its amount in
    ``t+1``. A month without a record for the triple counts as 0. (A row-wise
    ``shift`` would silently pull in values from other months whenever a
    triple skips a month.)

    When the target month has no data at all (for example, forecasting the
    month after the last one available), the label is unknown. Such rows are
    excluded from ``train_rows`` but kept in ``val_rows`` with ``next_amount``,
    ``y_amt`` and ``y_pos`` set to NaN, so they can still be used for
    inference.

    If either result would be empty, a warning is printed and a single
    labelled row is substituted so downstream code keeps running.

    Args:
        flows: Raw flows accepted by ``_prepare_monthly``.
        holdout_month: Target month of the validation rows, e.g. ``"2024-10"``.
        W_months: Number of training feature months.
        lag_K: Number of monthly lag features.
    """
    key = ["exporter_id", "importer_id", "product_id"]

    dfm = _prepare_monthly(flows)
    print("[HOLDOUT] dfm months:", dfm["ym"].dt.to_period("M").unique())

    hm = pd.Period(holdout_month, freq="M")
    Tval = hm - 1               # feature month for validation (t = hm-1)
    Tend = hm - 2               # last training feature month
    T0 = Tend - (W_months - 1)  # first training feature month

    # Keep lag_K months of history before T0 and the holdout month itself;
    # the latter is needed only as the label of the validation rows.
    start_for_lags = (T0 - lag_K).to_timestamp()
    end_for_lags = hm.to_timestamp()
    sli = dfm[(dfm["ym"] >= start_for_lags) & (dfm["ym"] <= end_for_lags)].copy()
    sli = sli.sort_values(key + ["ym"])

    # A running month counter turns "k months earlier/later" into integer
    # arithmetic that can be used as a join key.
    month_no = (sli["ym"].dt.year * 12 + sli["ym"].dt.month).to_numpy()

    observed = sli[key].copy()
    observed["_month_no"] = month_no
    observed["_amount"] = sli["amount_usd"].to_numpy()

    def _amount_at(offset: int) -> np.ndarray:
        """Amount of each row's triple `offset` months away (NaN if no record)."""
        probe = sli[key].copy()
        probe["_month_no"] = month_no + offset
        # (triple, month) is unique after _prepare_monthly, so the left join
        # returns exactly one row per probe row, in probe order.
        hit = probe.merge(observed, on=key + ["_month_no"], how="left")
        return hit["_amount"].to_numpy(dtype=float)

    for k in range(1, lag_K + 1):
        sli[f"lag{k}"] = np.nan_to_num(_amount_at(-k), nan=0.0)

    # The label is known only when the following month exists in the slice;
    # inside covered months a missing record means zero trade.
    label_known = np.isin(month_no + 1, np.unique(month_no))
    sli["next_amount"] = np.where(label_known, np.nan_to_num(_amount_at(1), nan=0.0), np.nan)

    # Month-of-year features (sin/cos)
    ang = 2 * np.pi * (sli["ym"].dt.month - 1) / 12.0
    sli["m_sin"], sli["m_cos"] = np.sin(ang), np.cos(ang)

    # Targets (NaN where the label is unknown)
    sli["y_amt"] = sli["next_amount"].astype(np.float32)
    sli["y_pos"] = (sli["y_amt"] > 0).astype(np.float32).where(label_known, np.nan)

    labelled = sli[label_known]
    in_train_window = (labelled["ym"] >= T0.to_timestamp()) & (labelled["ym"] <= Tend.to_timestamp())
    train_rows = labelled[in_train_window].copy()
    val_rows = sli[sli["ym"] == Tval.to_timestamp()].copy()  # features at hm-1 -> target hm

    if train_rows.empty:
        print(f"[WARN] train_rows empty for holdout={holdout_month}; filling dummy 1-row DataFrame.")
        train_rows = labelled.head(1).copy()
    if val_rows.empty:
        print(f"[WARN] val_rows empty for holdout={holdout_month}; filling dummy 1-row DataFrame.")
        val_rows = labelled.head(1).copy()

    return train_rows, val_rows


# ID mappers & partner selection


def build_id_maps(df: pd.DataFrame) -> Tuple[Dict[str, int], Dict[str, int], Dict[str, int]]:
    """Map country and product ids to consecutive integers in sorted order.

    Countries are every id seen as exporter or importer; exporters and
    importers share one index space, so the same dict object is returned for
    both.

    Returns:
        ``(exporter_map, importer_map, product_map)`` where the first two are
        the identical country mapping.
    """
    country_ids = sorted(set(df["exporter_id"]) | set(df["importer_id"]))
    product_ids = sorted(df["product_id"].unique())
    country_index = dict(zip(country_ids, range(len(country_ids))))
    product_index = dict(zip(product_ids, range(len(product_ids))))
    return country_index, country_index, product_index


def latest_full_year_slice(flows: pd.DataFrame) -> tuple[pd.DataFrame, int]:
    """Select the rows of the most recent calendar year that has 12 distinct months.

    When no year is complete, the latest year present is used instead.
    ``flows["date"]`` must already be a datetime column.

    Returns:
        ``(rows_of_that_year, year)``.
    """
    years = flows["date"].dt.year
    months = flows["date"].dt.month

    distinct_months = months.groupby(years).nunique()
    complete_years = distinct_months[distinct_months >= 12].index

    if len(complete_years) > 0:
        chosen = int(complete_years.max())
    else:
        chosen = int(years.max())
    return flows[years == chosen], chosen


def top20_partners_by_year(
    flows_year: pd.DataFrame,
    anchor: str,
    direction: str,
    min_products: int = 200,
    top_n: int = 20,
) -> list[str]:
    """Return the anchor's largest trade partners in one direction.

    Partners are ranked by total ``amount_usd`` in the requested direction.
    Only partners that trade at least ``min_products`` distinct products with
    the anchor (exports and imports combined, counting only rows with a
    positive amount) are eligible.

    Args:
        flows_year: Flows to rank on (typically one calendar year).
        anchor: Anchor country id, e.g. ``'USA'`` or ``'CHN'``.
        direction: ``"import"`` ranks the anchor's sources, ``"export"`` its
            destinations.
        min_products: Minimum number of distinct products for eligibility.
        top_n: Maximum number of partners to return.

    Returns:
        Partner ids, largest first.

    Raises:
        ValueError: If ``direction`` is neither ``"import"`` nor ``"export"``.
    """
    if direction not in {"import", "export"}:
        raise ValueError("direction must be 'import' or 'export'")

    anchor_exports = flows_year["exporter_id"].eq(anchor)
    anchor_imports = flows_year["importer_id"].eq(anchor)

    if direction == "import":
        rank = flows_year[anchor_imports].groupby("exporter_id")["amount_usd"].sum()
    else:  # export
        rank = flows_year[anchor_exports].groupby("importer_id")["amount_usd"].sum()

    # Product breadth per partner, either direction. Zero-valued rows are not
    # trade and must not count towards the threshold.
    traded = flows_year[(anchor_exports | anchor_imports) & flows_year["amount_usd"].gt(0)]
    partner = np.where(traded["exporter_id"].eq(anchor), traded["importer_id"], traded["exporter_id"])
    breadth = traded.assign(partner=partner).groupby("partner")["product_id"].nunique()

    eligible = breadth.index[(breadth >= min_products).to_numpy()]
    ranked = rank[rank.index.isin(eligible)].sort_values(ascending=False)
    return ranked.head(top_n).index.tolist()



# PyG Dataset for triples


@dataclass
class TripleBatch:
    """One mini-batch of (exporter, importer, product) samples.

    Built by ``collate_triples``, which stacks the per-sample tensors returned
    by ``TriplesDataset.__getitem__`` along a new leading batch dimension.
    """
    exp_idx: torch.Tensor     # exporter country node index per sample (long)
    imp_idx: torch.Tensor     # importer country node index per sample (long)
    prod_idx: torch.Tensor    # product node index per sample (long)
    lags: torch.Tensor        # lag1..lagK amounts per sample (float32)
    month_feat: torch.Tensor  # [m_sin, m_cos] month-of-year encoding per sample (float32)
    y_pos: torch.Tensor       # y_pos label column per sample (float32)
    y_amt: torch.Tensor       # y_amt label column per sample (float32)

class TriplesDataset(Dataset):
    """Tensor view over sample rows for (exporter, importer, product) triples.

    Rows whose exporter, importer or product id is missing from the supplied
    id maps are dropped. The surviving rows are kept in ``self.df`` (index
    reset) and mirrored as tensors:

    * ``exp_idx``, ``imp_idx``, ``prod_idx``: long node indices
    * ``lags``: float32, one column per ``lag1..lagK``
    * ``month_feat``: float32 ``[m_sin, m_cos]``
    * ``y_pos``, ``y_amt``: float32 labels

    Args:
        rows: Frame with ``exporter_id``, ``importer_id``, ``product_id``,
            ``m_sin``, ``m_cos``, ``y_pos``, ``y_amt`` and, ideally,
            ``lag1..lagK``. Missing lag columns are filled with zeros and a
            warning is printed. The caller's frame is not modified.
        country2nid: Country id -> node index.
        product2nid: Product id -> node index.
        lag_K: Number of lag columns to expose.
    """

    def __init__(self, rows: pd.DataFrame, country2nid: Dict[str, int], product2nid: Dict[str, int], lag_K: int):
        self.c2n = country2nid
        self.p2n = product2nid
        self.K = lag_K

        frame = rows.reset_index(drop=True)

        # map IDs; -1 marks ids that are not nodes of the graph
        exp_idx = torch.tensor([self.c2n.get(x, -1) for x in frame["exporter_id"]], dtype=torch.long)
        imp_idx = torch.tensor([self.c2n.get(x, -1) for x in frame["importer_id"]], dtype=torch.long)
        prod_idx = torch.tensor([self.p2n.get(x, -1) for x in frame["product_id"]], dtype=torch.long)

        # filter out rows with missing mapping
        keep = (exp_idx >= 0) & (imp_idx >= 0) & (prod_idx >= 0)

        # Select rows with a boolean array via .loc. A plain Python list would
        # be read as a column selection when it is empty, stripping every
        # column from an empty dataset.
        self.df = frame.loc[keep.numpy()].reset_index(drop=True)
        self.exp_idx = exp_idx[keep]
        self.imp_idx = imp_idx[keep]
        self.prod_idx = prod_idx[keep]

        # tensors
        lag_cols = [f"lag{k}" for k in range(1, self.K + 1)]
        for col in lag_cols:
            if col not in self.df.columns:
                print(f"[WARN] Missing {col}; filling with zeros.")
                self.df[col] = 0.0
        self.lags = torch.tensor(self.df[lag_cols].values, dtype=torch.float32)

        # month features
        self.month_feat = torch.tensor(self.df[["m_sin", "m_cos"]].values, dtype=torch.float32)
        # labels
        self.y_pos = torch.tensor(self.df["y_pos"].values, dtype=torch.float32)
        self.y_amt = torch.tensor(self.df["y_amt"].values, dtype=torch.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample as a 7-tuple in ``TripleBatch`` field order."""
        return (self.exp_idx[idx], self.imp_idx[idx], self.prod_idx[idx],
                self.lags[idx], self.month_feat[idx], self.y_pos[idx], self.y_amt[idx])

def collate_triples(batch: List[Tuple[torch.Tensor, ...]]) -> TripleBatch:
    """Stack a list of 7-tuples from ``TriplesDataset`` into one ``TripleBatch``.

    Each position of the sample tuples is stacked along a new leading batch
    dimension and assigned to the ``TripleBatch`` field of the same position.
    """
    exp, imp, prod, lags, month, occurred, amount = (
        torch.stack(column) for column in zip(*batch)
    )
    return TripleBatch(
        exp_idx=exp,
        imp_idx=imp,
        prod_idx=prod,
        lags=lags,
        month_feat=month,
        y_pos=occurred,
        y_amt=amount,
    )


# Model: Hetero GAT encoder + hurdle heads


class TradeHeteroGAT(nn.Module):
    """Heterogeneous GATv2 encoder over the country/product graph with hurdle heads.

    Country nodes start from a linear projection of their feature rows and
    product nodes from a learned embedding. Each encoder layer runs one
    ``GATv2Conv`` per relation (summed per node type), followed by LayerNorm,
    ELU and dropout.

    For every (exporter, importer, product) sample the three node embeddings
    are concatenated with the lag and month features and fed to two MLP heads:

    * ``head_occ``: logit that the flow is positive
    * ``head_amt_log1p``: ``log1p`` of the amount, given that it is positive

    Args:
        in_dim_country: Width of the country feature rows.
        n_products: Number of product nodes.
        d_model: Width of the initial country/product representations.
        gat_hidden: Output width of every encoder layer; must be divisible by
            ``gat_heads``.
        gat_heads: Attention heads per ``GATv2Conv``.
        gat_layers: Number of encoder layers.
        lag_dim: Number of lag features per sample.
        time_dim: Number of month features per sample.
        dropout: Dropout rate used in attention, encoder and heads.

    Raises:
        ValueError: If ``gat_hidden`` is not a multiple of ``gat_heads``.
    """

    # Relations of the graph, forward and reverse, in registration order.
    _EDGE_TYPES = (
        ("country", "exports_rca", "product"),
        ("product", "exports_rca_rev", "country"),
        ("country", "imports_rca", "product"),
        ("product", "imports_rca_rev", "country"),
    )

    def __init__(self, in_dim_country: int, n_products: int, d_model: int = 128, gat_hidden: int = 128,
                 gat_heads: int = 4, gat_layers: int = 2, lag_dim: int = 3, time_dim: int = 2, dropout: float = 0.2):
        super().__init__()
        if gat_hidden % gat_heads != 0:
            # With concat=True a layer outputs heads * (gat_hidden // heads)
            # channels, which would not match LayerNorm(gat_hidden).
            raise ValueError(
                f"gat_hidden ({gat_hidden}) must be divisible by gat_heads ({gat_heads})"
            )

        self.proj_country = nn.Linear(in_dim_country, d_model, bias=False)
        self.emb_product = nn.Embedding(n_products, d_model)

        self.convs = nn.ModuleList()
        self.norm_country = nn.ModuleList()
        self.norm_product = nn.ModuleList()

        head_dim = gat_hidden // gat_heads
        for layer in range(gat_layers):
            # The first layer sees d_model-wide inputs; later layers see the
            # gat_hidden-wide output of the previous layer.
            in_dim = d_model if layer == 0 else gat_hidden
            conv = HeteroConv({
                edge_type: GATv2Conv((in_dim, in_dim), head_dim, heads=gat_heads, dropout=dropout,
                                     add_self_loops=False, concat=True)
                for edge_type in self._EDGE_TYPES
            }, aggr="sum")
            self.convs.append(conv)
            self.norm_country.append(nn.LayerNorm(gat_hidden))
            self.norm_product.append(nn.LayerNorm(gat_hidden))

        self.dropout = nn.Dropout(dropout)
        self.emb_dim = gat_hidden

        # exporter + importer + product embeddings, then lags and month features
        pair_dim = 3 * self.emb_dim + lag_dim + time_dim
        hid = 256
        self.head_occ = self._mlp_head(pair_dim, hid, dropout)
        self.head_amt_log1p = self._mlp_head(pair_dim, hid, dropout)

    @staticmethod
    def _mlp_head(in_dim: int, hid: int, dropout: float) -> nn.Sequential:
        """Three-layer MLP with a single output unit."""
        return nn.Sequential(
            nn.Linear(in_dim, hid), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hid, hid // 2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hid // 2, 1)
        )

    @staticmethod
    def _expected_usd(logit: torch.Tensor, pred_log1p: torch.Tensor) -> torch.Tensor:
        """Unconditional dollars: P(flow > 0) * max(expm1(pred_log1p), 0).

        The clamp is applied after undoing the log1p, so a negative
        log-amount maps to 0 dollars rather than a negative amount.
        """
        return torch.sigmoid(logit) * torch.expm1(pred_log1p).clamp_min(0.0)

    def encode(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Return the final node embeddings keyed by ``'country'`` and ``'product'``."""
        x_dict = {
            "country": self.proj_country(data["country"].x.float()),
            "product": self.emb_product.weight,
        }
        for conv, ln_c, ln_p in zip(self.convs, self.norm_country, self.norm_product):
            x_dict = conv(x_dict, data.edge_index_dict)
            # x_dict contains tensors for both types
            x_dict["country"] = self.dropout(F.elu(ln_c(x_dict["country"])))
            x_dict["product"] = self.dropout(F.elu(ln_p(x_dict["product"])))
        return x_dict

    def forward(self, data: HeteroData, batch: TripleBatch) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score a batch of triples.

        Returns:
            ``(logit, pred_log1p_amt_pos, yhat)``: the occurrence logit, the
            predicted ``log1p`` amount conditional on a positive flow, and the
            unconditional (non-negative) dollar prediction.
        """
        x = self.encode(data)
        he = x["country"][batch.exp_idx]
        hi = x["country"][batch.imp_idx]
        hp = x["product"][batch.prod_idx]
        feat = torch.cat([he, hi, hp, batch.lags, batch.month_feat], dim=-1)
        logit = self.head_occ(feat).squeeze(-1)
        pred_log1p_amt_pos = self.head_amt_log1p(feat).squeeze(-1)
        yhat = self._expected_usd(logit, pred_log1p_amt_pos)
        return logit, pred_log1p_amt_pos, yhat


# Training / evaluation loops


def train_epoch(model: TradeHeteroGAT, data: HeteroData, loader: DataLoader, opt: torch.optim.Optimizer,
                device: str, lambda_amt: float = 1.0, pos_weight: torch.Tensor | None = None) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The loss of a batch is the occurrence BCE (on logits, optionally weighted
    by ``pos_weight``) plus ``lambda_amt`` times a smooth-L1 loss between the
    predicted and true ``log1p`` amounts, the latter taken over samples with a
    positive label only. The graph is re-encoded by the model for every batch.

    Returns:
        Sum of the batch losses divided by the number of batches (0.0 for an
        empty loader).
    """
    model.train()
    occurrence_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    amount_loss = nn.SmoothL1Loss()
    field_names = ("exp_idx", "imp_idx", "prod_idx", "lags", "month_feat", "y_pos", "y_amt")

    running = 0.0
    for cpu_batch in loader:
        # Move every tensor of the batch to the training device.
        batch = TripleBatch(**{name: getattr(cpu_batch, name).to(device) for name in field_names})

        opt.zero_grad()
        logit, pred_log1p, _ = model(data, batch)

        loss_prob = occurrence_loss(logit, batch.y_pos)
        positives = batch.y_pos > 0.5
        if positives.any():
            loss_amt = amount_loss(pred_log1p[positives], torch.log1p(batch.y_amt[positives]))
        else:
            # No positive flow in this batch: the amount head gets no signal.
            loss_amt = torch.tensor(0.0, device=device)

        loss = loss_prob + lambda_amt * loss_amt
        loss.backward()
        opt.step()
        running += loss.item()

    return running / max(1, len(loader))

@torch.no_grad()
def eval_epoch(model: TradeHeteroGAT, data: HeteroData, loader: DataLoader, device: str) -> Dict[str, float]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Per batch, the loss is the occurrence BCE plus the smooth-L1 loss on
    ``log1p`` amounts of positive samples, and sMAPE compares the true amounts
    with the model's unconditional dollar prediction.

    Batch results are averaged weighted by batch size, so a short final batch
    does not count as much as a full one and the reported sMAPE equals the
    sMAPE over all samples.

    Returns:
        ``{"loss": ..., "sMAPE": ...}``; both are NaN when the loader yields no
        samples.
    """
    model.eval()
    bce = nn.BCEWithLogitsLoss()
    huber = nn.SmoothL1Loss()
    loss_sum, smape_sum, n_seen = 0.0, 0.0, 0
    for batch in loader:
        batch = TripleBatch(
            exp_idx=batch.exp_idx.to(device),
            imp_idx=batch.imp_idx.to(device),
            prod_idx=batch.prod_idx.to(device),
            lags=batch.lags.to(device),
            month_feat=batch.month_feat.to(device),
            y_pos=batch.y_pos.to(device),
            y_amt=batch.y_amt.to(device),
        )
        logit, pred_log1p, yhat = model(data, batch)
        loss_prob = bce(logit, batch.y_pos)
        mask = (batch.y_pos > 0.5)
        if mask.any():
            loss_amt = huber(pred_log1p[mask], torch.log1p(batch.y_amt[mask]))
        else:
            loss_amt = torch.tensor(0.0, device=device)
        loss = loss_prob + loss_amt

        n = int(batch.y_amt.numel())
        loss_sum += loss.item() * n
        smape_sum += smape(batch.y_amt.detach().cpu(), yhat.detach().cpu()) * n
        n_seen += n

    if n_seen == 0:
        return {"loss": float("nan"), "sMAPE": float("nan")}
    return {"loss": loss_sum / n_seen, "sMAPE": smape_sum / n_seen}


#Backtest windows & final train


def _parse_flow_dates(dates: pd.Series) -> pd.Series:
    """Coerce a 'date' column (YYYYMM numbers or date-like values) to datetimes."""
    try:
        if pd.api.types.is_numeric_dtype(dates):
            return pd.to_datetime(dates.astype(int).astype(str), format="%Y%m", errors="coerce")
        return pd.to_datetime(dates, errors="coerce")
    except (ValueError, TypeError, OverflowError) as e:
        print(f"[WARN] Could not parse 'date' column directly: {e}")
        return pd.to_datetime(dates.astype(str), errors="coerce")


def _cpu_state_snapshot(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Return an independent CPU copy of the model's state dict."""
    # .cpu() is a no-op for tensors already on the CPU, so without clone() the
    # snapshot would keep tracking the live parameters.
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


def _batch_to_device(batch: TripleBatch, device: str) -> TripleBatch:
    """Return a copy of the batch with every tensor moved to `device`."""
    return TripleBatch(
        exp_idx=batch.exp_idx.to(device),
        imp_idx=batch.imp_idx.to(device),
        prod_idx=batch.prod_idx.to(device),
        lags=batch.lags.to(device),
        month_feat=batch.month_feat.to(device),
        y_pos=batch.y_pos.to(device),
        y_amt=batch.y_amt.to(device),
    )


def fit_predict_tradeflows(
    flows: pd.DataFrame,
    windows: Iterable[int] = (12, 18, 24),
    lag_K: int = 3,
    epochs: int = 10,
    batch_size: int = 8192,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    backtest_holdout: str = "2024-10",
):
    """Select a training window by backtest, refit, and forecast the next month.

    Steps:
      1. Parse ``flows["date"]`` and build the RCA graph from all flows.
      2. For every window length in ``windows``, train a fresh model and
         evaluate it on ``backtest_holdout``; keep the length with the lowest
         validation sMAPE over the epochs.
      3. Retrain with that length for the month following the last month in
         the data, keeping the weights of the epoch with the lowest training
         loss.
      4. Predict every inference row and extract the submission subset: flows
         between USA/CHN and their top partners of the latest full year.

    Args:
        flows: Frame with ``date``, ``exporter_id``, ``importer_id``,
            ``product_id`` and ``amount_usd``.
        windows: Candidate numbers of training months.
        lag_K: Number of monthly lag features.
        epochs: Training epochs per model.
        batch_size: Mini-batch size for training and inference.
        device: Torch device string.
        backtest_holdout: Target month (``"YYYY-MM"``) used for the backtest.

    Returns:
        Dict with ``selected_W``, ``graph``, ``model``, ``partner_year``,
        ``inference_all`` (all inference rows with ``pred_usd`` and
        ``target_month``) and ``submission``.

    Raises:
        ValueError: If no date can be parsed, ``windows`` is empty, or there
            are fewer than ``lag_K + 1`` distinct months.
    """
    print("=== FUNCTION CHECK (entering fit_predict_tradeflows) ===")

    # Work on a copy so the caller's frame keeps its original 'date' column.
    flows = flows.copy()
    flows["date"] = _parse_flow_dates(flows["date"])

    print("date dtype after coercion:", flows["date"].dtype)
    print("num NaT:", flows["date"].isna().sum())
    if not flows["date"].notna().any():
        raise ValueError("All 'date' values failed to parse. Check input format.")

    all_months = flows["date"].dropna().dt.to_period("M").sort_values().unique()
    print("distinct months:", all_months)
    print("num distinct months:", len(all_months))

    # Fail before any training happens rather than after the backtest.
    windows = list(windows)
    if not windows:
        raise ValueError("windows must contain at least one lookback length.")
    if len(all_months) < lag_K + 1:
        raise ValueError(f"Not enough months ({len(all_months)}) for lag_K={lag_K}. Need at least {lag_K+1}.")

    min_date, max_date = flows["date"].min(), flows["date"].max()
    print(f"[INFO] Building RCA graph dynamically using data from {min_date:%Y-%m} → {max_date:%Y-%m}")

    # Build RCA graph on all available data.
    # NOTE(review): this includes the backtest holdout month and everything
    # after it, so the backtest graph has seen future trade; confirm that
    # this is acceptable.
    graph, country2nid, product2nid, product_list = build_hetero_graph_from_rca(flows)
    graph = graph.to(device)
    in_dim_country = 2 * len(product_list)

    def _new_model() -> TradeHeteroGAT:
        return TradeHeteroGAT(in_dim_country=in_dim_country, n_products=len(product_list), d_model=128,
                              gat_hidden=128, gat_heads=4, gat_layers=2, lag_dim=lag_K, time_dim=2,
                              dropout=0.2).to(device)

    def _pos_weight(ds: TriplesDataset) -> torch.Tensor:
        # Weight positives by the negative/positive ratio of the training labels.
        pos_rate = float(ds.y_pos.mean().item()) if len(ds) > 0 else 0.5
        return torch.tensor([(1.0 - pos_rate) / max(pos_rate, 1e-6)], device=device)

    # Backtest window selection
    best_W, best_smape = None, 1e9
    for W in windows:
        print(f"\n=== Backtest W={W} months ===")
        train_rows, val_rows = build_samples_for_holdout(flows, holdout_month=backtest_holdout, W_months=W, lag_K=lag_K)

        train_ds = TriplesDataset(train_rows, country2nid, product2nid, lag_K)
        val_ds = TriplesDataset(val_rows, country2nid, product2nid, lag_K)
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_triples)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_triples)

        model = _new_model()
        opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
        pos_weight_tensor = _pos_weight(train_ds)

        # Only the score is needed for window selection; the backtest weights
        # themselves are never reused, so no state is snapshotted here.
        window_smape = 1e9
        for ep in range(1, epochs + 1):
            tr_loss = train_epoch(model, graph, train_loader, opt, device, lambda_amt=1.0, pos_weight=pos_weight_tensor)
            ev = eval_epoch(model, graph, val_loader, device)
            print(f"ep {ep:02d} | train_loss {tr_loss:.4f} | val_loss {ev['loss']:.4f} | val_sMAPE {ev['sMAPE']:.4f}")
            if ev["sMAPE"] < window_smape:
                window_smape = ev["sMAPE"]

        if window_smape < best_smape:
            best_W, best_smape = W, window_smape

    if best_W is None:
        # Every window produced a NaN/huge sMAPE (or epochs == 0).
        best_W = windows[0]
        print(f"[WARN] No window produced a finite sMAPE; falling back to W={best_W}.")

    print(f"\n>>> Selected W={best_W} by best sMAPE={best_smape:.4f}")

    # Dynamic next-month forecast
    max_month = all_months.max()
    latest_holdout = (max_month + 1).strftime("%Y-%m")
    print(f"[INFO] Data covers {all_months.min()} → {max_month}; forecasting {latest_holdout}")

    # Final training
    W_final = int(best_W)
    train_rows, val_rows = build_samples_for_holdout(flows, holdout_month=latest_holdout, W_months=W_final, lag_K=lag_K)

    train_ds = TriplesDataset(train_rows, country2nid, product2nid, lag_K)
    inf_ds = TriplesDataset(val_rows, country2nid, product2nid, lag_K)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_triples)
    inf_loader = DataLoader(inf_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_triples)

    final_model = _new_model()
    opt = torch.optim.AdamW(final_model.parameters(), lr=3e-4, weight_decay=1e-4)
    pos_weight_tensor = _pos_weight(train_ds)

    # No labels exist for the forecast month, so keep the epoch with the
    # lowest training loss.
    best_state, best_loss = None, 1e9
    for ep in range(1, epochs + 1):
        tr_loss = train_epoch(final_model, graph, train_loader, opt, device, lambda_amt=1.0, pos_weight=pos_weight_tensor)
        print(f"[final] ep {ep:02d} | train_loss {tr_loss:.4f}")
        if tr_loss < best_loss:
            best_loss = tr_loss
            best_state = _cpu_state_snapshot(final_model)
    if best_state is not None:
        final_model.load_state_dict(best_state)

    # --- Inference ---
    final_model.eval()
    preds = []
    with torch.no_grad():
        for batch in inf_loader:
            _, _, yhat = final_model(graph, _batch_to_device(batch, device))
            preds.append(yhat.detach().cpu())
    yhat_t = torch.cat(preds).view(-1) if preds else torch.empty(0)

    inf_df = inf_ds.df.copy()
    inf_df["pred_usd"] = yhat_t.numpy()
    inf_df["target_month"] = pd.Period(latest_holdout, freq="M").to_timestamp()

    # submission mask using latest full year
    flows_year, partner_year = latest_full_year_slice(flows)
    us_sources = top20_partners_by_year(flows_year, anchor="USA", direction="import")
    us_dests   = top20_partners_by_year(flows_year, anchor="USA", direction="export")
    cn_sources = top20_partners_by_year(flows_year, anchor="CHN", direction="import")
    cn_dests   = top20_partners_by_year(flows_year, anchor="CHN", direction="export")

    m = (
        (inf_df["importer_id"].eq("USA") & inf_df["exporter_id"].isin(us_sources)) |
        (inf_df["exporter_id"].eq("USA") & inf_df["importer_id"].isin(us_dests))   |
        (inf_df["importer_id"].eq("CHN") & inf_df["exporter_id"].isin(cn_sources)) |
        (inf_df["exporter_id"].eq("CHN") & inf_df["importer_id"].isin(cn_dests))
    )
    submission = inf_df[m].copy()
    submission = submission[["target_month", "exporter_id", "importer_id", "product_id", "pred_usd"]]

    return {
        "selected_W": W_final,
        "graph": graph,
        "model": final_model,
        "partner_year": partner_year,
        "inference_all": inf_df,
        "submission": submission,
    }








In [99]:
import pandas as pd
import numpy as np

def load_trade_file(path, anchor):
    """Read one reporter's trade CSV and reshape it into the pipeline schema.

    The file is expected to provide ``month_id``, ``trade_value``,
    ``trade_flow_name``, ``country_id`` and ``product_id``. Rows whose
    ``trade_flow_name`` is ``"Exports"`` get ``anchor`` as exporter and
    ``country_id`` as importer; every other row gets ``country_id`` as
    exporter and ``anchor`` as importer.

    Args:
        path: Anything ``pandas.read_csv`` accepts.
        anchor: Id of the reporting country, e.g. ``"USA"``.

    Returns:
        Frame with columns ``date``, ``exporter_id``, ``importer_id``,
        ``product_id`` and ``amount_usd``.
    """
    column_names = {
        "province_name": "state_name",  # unify province/state naming
        "month_id": "date",
        "trade_value": "amount_usd",
    }
    # rename() ignores keys that are not present, so one pass covers files
    # with and without a province column.
    table = pd.read_csv(path).rename(columns=column_names)

    anchor_is_exporter = table["trade_flow_name"] == "Exports"
    table["exporter_id"] = np.where(anchor_is_exporter, anchor, table["country_id"])
    table["importer_id"] = np.where(anchor_is_exporter, table["country_id"], anchor)

    wanted = ["date", "exporter_id", "importer_id", "product_id", "amount_usd"]
    return table[wanted]


# --- load and combine all six files (USA and CHN reporters, 2023-2025) ---
usa_2023 = load_trade_file("trade_s_usa_state_m_hs_2023.csv", "USA")
usa_2024 = load_trade_file("trade_s_usa_state_m_hs_2024.csv", "USA")
usa_2025 = load_trade_file("trade_s_usa_state_m_hs_2025.csv", "USA")
chn_2023 = load_trade_file("trade_s_chn_m_hs_2023.csv", "CHN")
chn_2024 = load_trade_file("trade_s_chn_m_hs_2024.csv", "CHN")
chn_2025 = load_trade_file("trade_s_chn_m_hs_2025.csv", "CHN")

flows = pd.concat([usa_2023, usa_2024, usa_2025, chn_2023, chn_2024, chn_2025], ignore_index=True)

# Collapse to one row per (date, exporter, importer, product) by summing the
# amounts; load_trade_file drops the state/province column, so this adds up
# the sub-national rows of a file.
# NOTE(review): the sum also adds together rows that come from different
# files. If the USA and CHN files both report the same USA<->CHN flow (and
# use the same country codes), that flow is double-counted here rather than
# de-duplicated — confirm which behaviour is intended.
flows = (
    flows.groupby(["date", "exporter_id", "importer_id", "product_id"], as_index=False)
         .agg(amount_usd=("amount_usd", "sum"))
)

# optional: confirm date coverage ('date' is still the raw month_id here)
print(flows["date"].min(), flows["date"].max(), flows.shape)


202301 202503 (18811510, 5)


In [None]:
# Run the full pipeline: backtest each window, refit with the best one and
# forecast the month after the last month present in `flows`.
results = fit_predict_tradeflows(
    flows,
    windows=(12, 18, 24),  # candidate training-window lengths in months; the best backtest sMAPE wins
    lag_K=3,                # use 3-month lag features
    epochs=15,              # training epochs per backtest window and for the final fit
    batch_size=8192,
    device="cuda" if torch.cuda.is_available() else "cpu",
)


=== FUNCTION CHECK (entering fit_predict_tradeflows) ===
date dtype after coercion: datetime64[ns]
num NaT: 0
distinct months: <PeriodArray>
['2023-01', '2023-02', '2023-03', '2023-04', '2023-05', '2023-06', '2023-07',
 '2023-08', '2023-09', '2023-10', '2023-11', '2023-12', '2024-01', '2024-02',
 '2024-03', '2024-04', '2024-05', '2024-06', '2024-07', '2024-08', '2024-09',
 '2024-10', '2024-11', '2024-12', '2025-01', '2025-02', '2025-03']
Length: 27, dtype: period[M]
num distinct months: 27
[INFO] Building RCA graph dynamically using data from 2023-01 → 2025-03

=== Backtest W=12 months ===
[HOLDOUT] dfm months: <PeriodArray>
['2023-01', '2023-02', '2023-03', '2023-04', '2023-05', '2023-06', '2023-07',
 '2023-08', '2023-09', '2023-10', '2023-11', '2023-12', '2024-01', '2024-02',
 '2024-03', '2024-04', '2024-05', '2024-06', '2024-07', '2024-08', '2024-09',
 '2024-10', '2024-11', '2024-12', '2025-01', '2025-02', '2025-03']
Length: 27, dtype: period[M]
ep 01 | train_loss 704.3120 | val_los

KeyboardInterrupt: 