In [None]:
"""
gnn_experiment_refactored.py
—————————————————————————————————
End-to-end GNN experimentation wrapper for **binary rare-event
classification** on temporal node snapshots.

Key additions vs. original version
• Per-node positive-class weights built into BCE loss
• Edge-dropout (train-time only)
• Configurable LR scheduler with LR print at each epoch
"""

# ----------------------------------------------------------------------
# 0. Imports
# ----------------------------------------------------------------------
import json
import pathlib
from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional, Union, List, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GCNConv,
    GATv2Conv,
    GCN,
    GraphSAGE,
    GAT,
)


# ----------------------------------------------------------------------
# 1. Configuration object
# ----------------------------------------------------------------------
@dataclass
class GNNConfig:
    """
    Flat container for every hyper-parameter of a GNN experiment.

    Instances can be built directly, from a plain ``dict`` of field values,
    or from a ``.json`` file via :meth:`load`.  Several fields (``task``,
    ``norm``, ``class_weights``, ``in_dropout``, ``run_name``, ``seed``) are
    not read by the code visible in this section of the file.
    """

    # --- model ---------------------------------------------------------
    task: str = "node_clf"            # {"node_clf", "edge_clf"}
    model_name: str = "gcn"           # {"gcn","graphsage","gat","gatv2"}
    num_layers: int = 2
    hidden_dim: int = 128
    heads: int = 8                    # attention heads (attention models only)
    norm: str = "batch"               # {"batch","layer",None}

    # --- data & label --------------------------------------------------
    target_col: str = "target"        # label column inside each node frame

    # --- split ---------------------------------------------------------
    split_mode: str = "date"          # {"date","ratio"}
    cutoff_date: Optional[str] = None  # required when split_mode == "date"
    val_ratio: float = 0.10
    test_ratio: float = 0.10
    shuffle_in_split: bool = False    # shuffles the training loader only

    # --- optimisation --------------------------------------------------
    batch_size: int = 32
    lr: float = 1e-3
    epochs: int = 100
    patience: int = 10                # early-stopping patience (epochs)
    grad_clip: float = 1.0            # max grad-norm; None disables clipping

    optimiser: str = "adam"           # {"adam","adamw","sgd"}
    optimiser_kwargs: Dict[str, Any] = field(default_factory=dict)

    # --- LR scheduler --------------------------------------------------
    lr_scheduler: Optional[str] = None         # {"step","plateau","cosine","onecycle"}
    lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
    print_lr_each_epoch: bool = True

    # --- loss / weights -----------------------------------------------
    loss_fn: str = "bce"              # {"bce","cross_entropy"}
    class_weights: Optional[List[float]] = None   # global BCE/CE weights
    node_pos_weights: Optional[List[float]] = None  # per-node list (len == #nodes)

    # --- regularisation -----------------------------------------------
    in_dropout: float = 0.0           # feature-level dropout
    layer_dropout: float = 0.5        # dropout applied before the output head
    edge_dropout: float = 0.0         # probability of dropping each edge at train-time

    # --- misc ----------------------------------------------------------
    device: str = "cuda"
    run_name: str = "default_run"
    seed: int = 42

    # ------------------------------------------------------------------
    @staticmethod
    def load(cfg: Union["GNNConfig", str, Dict[str, Any]]) -> "GNNConfig":
        """
        Normalise *cfg* into a ``GNNConfig``.

        * ``GNNConfig`` – returned as-is (same object, no copy).
        * ``dict``      – expanded into the constructor as keyword arguments.
        * ``str`` / ``pathlib.Path`` – opened as a file; only the ``.json``
          suffix is supported.

        Raises
        ------
        ValueError
            If the path does not end in ``.json``.
        TypeError
            If *cfg* is none of the supported types.
        """
        if isinstance(cfg, GNNConfig):
            return cfg
        if isinstance(cfg, dict):
            return GNNConfig(**cfg)
        if not isinstance(cfg, (str, pathlib.Path)):
            raise TypeError(f"Unsupported cfg type: {type(cfg)}")

        cfg_path = pathlib.Path(cfg)
        # The file is opened before the suffix is inspected, so a missing
        # file surfaces as the usual OSError regardless of its extension.
        with open(cfg_path, "r", encoding="utf-8") as handle:
            if cfg_path.suffix != ".json":
                raise ValueError("Unsupported file type for config path")
            payload = json.load(handle)
        return GNNConfig(**payload)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain (recursively converted) dictionary."""
        return asdict(self)


# ----------------------------------------------------------------------
# 2. Utilities
# ----------------------------------------------------------------------
def has_nan(t: torch.Tensor) -> bool:
    """
    Return ``True`` if *t* contains any NaN or +/-Inf entry.

    The result is a genuine Python ``bool`` (the previous implementation
    returned a 0-dim tensor despite the annotation).  An empty tensor
    yields ``False``.
    """
    # isfinite is False exactly for NaN and +/-Inf, so one pass covers both.
    return not bool(torch.isfinite(t).all())


# ----------------------------------------------------------------------
# 3. Mini model zoo
# ----------------------------------------------------------------------
class GNNNodeSimple(torch.nn.Module):
    """
    Baseline node model: two ``GCNConv`` layers (each followed by ReLU),
    dropout, then a linear output head.

    Parameters
    ----------
    d_in   : number of input features per node
    hidden : width of both convolution layers
    d_out  : size of the output produced per node
    p_drop : dropout probability applied before the linear head
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, p_drop: float):
        super().__init__()
        self.conv1 = GCNConv(d_in, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Map ``data.x`` / ``data.edge_index`` to one output row per node."""
        edges = data.edge_index
        h = data.x
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(h, edges))
        return self.fc(self.dropout(h))


class GNNGCN(torch.nn.Module):
    """
    Wrapper around torch_geometric's ``GCN`` model (batch norm, ReLU),
    followed by dropout and a linear output head.

    Parameters
    ----------
    d_in   : number of input features per node
    hidden : hidden width passed to ``GCN`` and used as head input size
    d_out  : size of the output produced per node
    layers : number of layers passed to ``GCN``
    p_drop : dropout probability applied before the linear head
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, layers: int, p_drop: float):
        super().__init__()
        self.body = GCN(d_in, hidden, layers, norm="batch", act="relu")
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Run the GCN body on node features and edges, then the head."""
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(self.dropout(embeddings))


class GNNSage(torch.nn.Module):
    """
    Wrapper around torch_geometric's ``GraphSAGE`` model (batch norm, ReLU),
    followed by dropout and a linear output head.

    Parameters
    ----------
    d_in   : number of input features per node
    hidden : hidden width passed to ``GraphSAGE`` and used as head input size
    d_out  : size of the output produced per node
    layers : number of layers passed to ``GraphSAGE``
    p_drop : dropout probability applied before the linear head
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, layers: int, p_drop: float):
        super().__init__()
        self.body = GraphSAGE(d_in, hidden, layers, norm="batch", act="relu")
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Run the GraphSAGE body on node features and edges, then the head."""
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(self.dropout(embeddings))


class GNNGAT(torch.nn.Module):
    """
    Wrapper around torch_geometric's ``GAT`` model (batch norm, ReLU),
    followed by dropout and a linear output head.

    Parameters
    ----------
    d_in   : number of input features per node
    hidden : hidden width passed to ``GAT`` and used as head input size
    d_out  : size of the output produced per node
    layers : number of layers passed to ``GAT``
    heads  : number of attention heads passed to ``GAT``
    p_drop : dropout probability applied before the linear head
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, layers: int, heads: int, p_drop: float):
        super().__init__()
        self.body = GAT(d_in, hidden, layers, heads=heads, norm="batch", act="relu")
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Run the GAT body on node features and edges, then the head."""
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(self.dropout(embeddings))


class GNNGAT2(torch.nn.Module):
    """
    Two ``GATv2Conv`` layers that consume edge attributes, followed by
    dropout and a linear output head.

    The first layer concatenates its ``heads`` attention heads (output width
    ``hidden * heads``); the second uses a single, non-concatenated head and
    maps back to ``hidden``.

    Parameters
    ----------
    d_in     : number of input features per node
    hidden   : per-head hidden width
    d_out    : size of the output produced per node
    layers   : accepted for signature parity with the other models but not
               used here -- the architecture is always two layers deep
    heads    : number of attention heads in the first layer
    edge_dim : number of features per edge
    p_drop   : dropout probability applied before the linear head
    """

    def __init__(
        self,
        d_in: int,
        hidden: int,
        d_out: int,
        layers: int,
        heads: int,
        edge_dim: int,
        p_drop: float,
    ):
        super().__init__()
        self.conv1 = GATv2Conv(d_in, hidden, heads=heads, concat=True, edge_dim=edge_dim)
        self.conv2 = GATv2Conv(hidden * heads, hidden, heads=1, concat=False, edge_dim=edge_dim)
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Apply both attention layers (ReLU after each), dropout, then the head."""
        edges, edge_feats = data.edge_index, data.edge_attr
        h = data.x
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(h, edges, edge_feats))
        return self.fc(self.dropout(h))


# ----------------------------------------------------------------------
# 4. Experiment wrapper
# ----------------------------------------------------------------------
class GNNExperiment:
    """
    Raw node/edge time-series → PyG snapshots → DataLoaders → Model.

    Intended call order (every builder step returns ``self`` so calls chain):
    ``prepare_snapshots()`` → ``build_loaders()`` → ``init_model()`` →
    ``compile()`` → ``train()``.  Afterwards ``evaluate()``, ``predict()``
    and ``plot_history()`` can be used.
    """

    # ------------------------------------------------------------ constructor
    def __init__(
        self,
        node_frames: Dict[str, pd.DataFrame],
        edge_frames: Dict[str, pd.DataFrame],
        graph: nx.DiGraph,
        cfg: Optional[Union[GNNConfig, str, Dict[str, Any]]] = None,
    ):
        """
        Parameters
        ----------
        node_frames : one DataFrame per node, keyed by the node name used in
                      ``graph``; rows are looked up by timestamp and must
                      contain ``cfg.target_col`` plus the feature columns
        edge_frames : one DataFrame per edge, keyed by ``"<src>-<dst>"`` in
                      lower case (see ``_edge_key``); rows are looked up by
                      timestamp
        graph       : graph whose nodes/edges define the snapshot topology
        cfg         : ``GNNConfig``, dict of config fields, path to a JSON
                      config, or ``None`` for a fresh default config
        """
        # raw inputs
        self.node_frames = node_frames
        self.edge_frames = edge_frames
        self.graph = graph

        # configuration -- a new GNNConfig is created per instance when none
        # is supplied, so default-constructed experiments never share (and
        # accidentally mutate) one config object.
        self.cfg = GNNConfig() if cfg is None else GNNConfig.load(cfg)

        # runtime placeholders
        self.reg_order: List[str] = []
        self.edge_order: List[Tuple[int, int]] = []
        self.snapshots: Optional[List[Data]] = None

        self.train_dl = self.val_dl = self.test_dl = None

        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.loss_fn = None
        self.metric_fn = None

        self.history: Dict[str, list] = {"train_loss": [], "val_loss": [], "val_metric": []}
        self.best_val_loss = float("inf")
        self.best_ckpt = None
        self.patience_counter = 0

        # per-node weight lookup
        self._node_weight_lookup: Optional[torch.Tensor] = None

        # device -- anything other than an available "cuda" falls back to CPU
        self.device = torch.device(
            "cuda" if (self.cfg.device == "cuda" and torch.cuda.is_available()) else "cpu"
        )

    # ---------------------------------------------------------------- helpers
    @staticmethod
    def _edge_key(src: str, dst: str) -> str:
        """Return the ``edge_frames`` key for an edge: ``"<src>-<dst>"`` lower-cased."""
        return f"{src}-{dst}".lower()

    @staticmethod
    def _ensure_same_index(idxs: List[pd.DatetimeIndex]) -> List[pd.Timestamp]:
        """
        Return the sorted timestamps present in *every* supplied index.

        Raises
        ------
        ValueError
            If the intersection is empty.
        """
        common = sorted(set.intersection(*(set(ix) for ix in idxs)))
        if not common:
            raise ValueError("No common timestamps across supplied DataFrames")
        return common

    # ------------------------------------------------------- 1. snapshots
    def prepare_snapshots(self) -> "GNNExperiment":
        """
        Build one PyG ``Data`` object per timestamp shared by all frames.

        Nodes are ordered by ``sorted(graph.nodes)``; edges follow
        ``graph.edges``.  Each snapshot carries ``x`` (node features),
        ``y`` (labels from ``cfg.target_col``), ``edge_index``,
        ``edge_attr``, ``node_weight`` and ``snap_time`` (timestamp as
        integer nanoseconds).

        Raises
        ------
        ValueError
            If ``cfg.node_pos_weights`` has the wrong length or the frames
            share no timestamp.
        """
        # node & edge order
        self.reg_order = sorted(self.graph.nodes)
        node_pos = {n: i for i, n in enumerate(self.reg_order)}
        self.edge_order = [(node_pos[s], node_pos[d]) for (s, d) in self.graph.edges]

        # node-weight lookup (aligned with reg_order)
        if self.cfg.node_pos_weights is not None:
            if len(self.cfg.node_pos_weights) != len(self.reg_order):
                raise ValueError("len(node_pos_weights) must equal number of nodes")
            self._node_weight_lookup = torch.tensor(self.cfg.node_pos_weights, dtype=torch.float32)
        else:
            self._node_weight_lookup = None

        # timestamps intersection
        ts_common = self._ensure_same_index(
            [df.index for df in self.node_frames.values()]
            + [df.index for df in self.edge_frames.values()]
        )

        snapshots: List[Data] = []
        for ts in ts_common:
            # node features + label
            feats, labels = [], []
            for region in self.reg_order:
                row = self.node_frames[region].loc[ts]
                labels.append(row[self.cfg.target_col])
                feats.append(row.drop(self.cfg.target_col).to_numpy(dtype=np.float32))
            x = torch.tensor(np.vstack(feats), dtype=torch.float32)
            y = torch.tensor(labels, dtype=torch.float32)

            # edge features
            edge_rows = []
            for s_idx, d_idx in self.edge_order:
                s, d = self.reg_order[s_idx], self.reg_order[d_idx]
                edge_rows.append(
                    self.edge_frames[self._edge_key(s, d)].loc[ts].to_numpy(dtype=np.float32)
                )
            edge_attr = torch.tensor(np.vstack(edge_rows), dtype=torch.float32)
            edge_index = torch.tensor(np.array(self.edge_order).T, dtype=torch.long)

            # every snapshot gets its own copy of the weights (ones if unset)
            node_weight = (
                self._node_weight_lookup.clone()
                if self._node_weight_lookup is not None
                else torch.ones(len(self.reg_order))
            )

            snapshots.append(
                Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=y,
                    node_weight=node_weight,
                    snap_time=torch.tensor([pd.Timestamp(ts).value]),
                )
            )

        self.snapshots = snapshots
        return self

    # ------------------------------------------------------- 2. loaders
    def build_loaders(self) -> "GNNExperiment":
        """
        Split the snapshots chronologically and wrap them in DataLoaders.

        * ``split_mode == "date"``: snapshots up to and including
          ``cfg.cutoff_date`` train; the later ones are halved into
          validation (earlier half) and test (later half).
        * any other ``split_mode``: the last ``test_ratio`` / ``val_ratio``
          fractions become test / validation, the rest trains.

        Only the training loader honours ``cfg.shuffle_in_split``.

        Raises
        ------
        RuntimeError
            If ``prepare_snapshots()`` has not been called.
        ValueError
            If ``cutoff_date`` is missing in date mode, or the training or
            validation split ends up empty.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() before build_loaders()")

        snaps_sorted = sorted(self.snapshots, key=lambda g: g.snap_time.item())

        if self.cfg.split_mode == "date":
            if self.cfg.cutoff_date is None:
                raise ValueError("cutoff_date must be set for date split")
            cutoff_int = pd.Timestamp(self.cfg.cutoff_date).value
            train_set = [g for g in snaps_sorted if g.snap_time.item() <= cutoff_int]
            holdout = [g for g in snaps_sorted if g.snap_time.item() > cutoff_int]
            if not holdout:
                raise ValueError("No snapshots after cutoff_date for val/test split")
            mid = len(holdout) // 2
            val_set, test_set = holdout[:mid], holdout[mid:]
        else:
            n_total = len(snaps_sorted)
            n_test = int(n_total * self.cfg.test_ratio)
            n_val = int(n_total * self.cfg.val_ratio)
            n_train = n_total - n_val - n_test
            train_set = snaps_sorted[:n_train]
            val_set = snaps_sorted[n_train : n_train + n_val]
            test_set = snaps_sorted[n_train + n_val :]

        # Fail early with a readable message: an empty train/val loader would
        # otherwise surface as a ZeroDivisionError deep inside train().
        if not train_set:
            raise ValueError(
                "Training split is empty - adjust cutoff_date / split ratios "
                "or supply more snapshots"
            )
        if not val_set:
            raise ValueError(
                "Validation split is empty - adjust cutoff_date / val_ratio "
                "or supply more snapshots"
            )

        dl_kw = dict(batch_size=self.cfg.batch_size, shuffle=self.cfg.shuffle_in_split)
        self.train_dl = DataLoader(train_set, **dl_kw)
        self.val_dl = DataLoader(val_set, batch_size=self.cfg.batch_size, shuffle=False)
        self.test_dl = DataLoader(test_set, batch_size=self.cfg.batch_size, shuffle=False)
        return self

    # ------------------------------------------------------- 3. model
    def init_model(self) -> "GNNExperiment":
        """
        Instantiate the model named by ``cfg.model_name`` and move it to
        ``self.device``.  Input and edge dimensions are read from the first
        snapshot; the output is a single logit per node.

        Raises
        ------
        RuntimeError
            If ``prepare_snapshots()`` has not been called.
        ValueError
            If ``cfg.model_name`` is not recognised.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() first.")

        d_in = self.snapshots[0].x.size(1)
        d_out = 1  # binary logit
        edge_dim = self.snapshots[0].edge_attr.size(1)
        p_drop = self.cfg.layer_dropout

        name = self.cfg.model_name.lower()
        if name in {"simple", "baseline"}:
            model = GNNNodeSimple(d_in, self.cfg.hidden_dim, d_out, p_drop)
        elif name == "gcn":
            model = GNNGCN(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers, p_drop)
        elif name in {"graphsage", "sage"}:
            model = GNNSage(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers, p_drop)
        elif name == "gat":
            model = GNNGAT(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers, self.cfg.heads, p_drop)
        elif name in {"gatv2", "gat2"}:
            model = GNNGAT2(
                d_in,
                self.cfg.hidden_dim,
                d_out,
                self.cfg.num_layers,
                self.cfg.heads,
                edge_dim,
                p_drop,
            )
        else:
            raise ValueError(f"Unknown model_name '{self.cfg.model_name}'")

        self.model = model.to(self.device)
        return self

    # ------------------------------------------------------- 4. compile
    def compile(self) -> "GNNExperiment":
        """
        Create the loss, optimiser and (optional) LR scheduler from ``cfg``.

        Raises
        ------
        RuntimeError
            If ``init_model()`` has not been called.
        ValueError
            If the loss, optimiser or scheduler name is not recognised.
        """
        if self.model is None:
            raise RuntimeError("Call init_model() before compile()")

        # ---- loss ------------------------------------------------------
        if self.cfg.loss_fn.lower() in {"bce", "binary"}:
            # Element-wise losses so per-node weights can be applied later.
            # The metric is the same criterion, so val_metric == val_loss.
            self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
            self.metric_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
        elif self.cfg.loss_fn.lower() == "cross_entropy":
            # NOTE(review): init_model() always emits a single logit and the
            # labels are float tensors, which does not look like a valid
            # CrossEntropyLoss setup -- confirm this branch is actually used.
            self.loss_fn = torch.nn.CrossEntropyLoss()
            self.metric_fn = self.loss_fn
        else:
            raise ValueError(f"Unsupported loss_fn '{self.cfg.loss_fn}'")

        # ---- optimiser -------------------------------------------------
        opt_kw = dict(lr=self.cfg.lr, **self.cfg.optimiser_kwargs)
        opt_name = self.cfg.optimiser.lower()
        if opt_name == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), **opt_kw)
        elif opt_name == "adamw":
            self.optimizer = torch.optim.AdamW(self.model.parameters(), **opt_kw)
        elif opt_name == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), **opt_kw)
        else:
            raise ValueError(f"Unknown optimiser '{self.cfg.optimiser}'")

        # ---- scheduler -------------------------------------------------
        # Note: train() steps the scheduler once per epoch, so scheduler
        # kwargs (step_size, T_max, total_steps, ...) are counted in epochs.
        self.scheduler = None
        if self.cfg.lr_scheduler is not None:
            sname = self.cfg.lr_scheduler.lower()
            kw = self.cfg.lr_scheduler_kwargs
            if sname == "step":
                self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, **kw)
            elif sname == "cosine":
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, **kw)
            elif sname == "plateau":
                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **kw)
            elif sname == "onecycle":
                self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, **kw)
            else:
                raise ValueError(f"Unknown lr_scheduler '{self.cfg.lr_scheduler}'")

        # track last LR for pretty print
        self._last_lr = self.optimizer.param_groups[0]["lr"]
        return self

    # ------------------------------------------------ edge-dropout helper
    def _apply_edge_dropout(self, data: Data) -> Data:
        """
        Randomly drop edges (and their attributes) with probability
        ``cfg.edge_dropout``.  No-op when the probability is <= 0 or the
        model is in eval mode.  The batch object is modified in place and
        returned.
        """
        p = getattr(self.cfg, "edge_dropout", 0.0)
        if p <= 0.0 or not self.model.training:
            return data
        E = data.edge_index.size(1)
        keep_mask = torch.rand(E, device=data.edge_index.device) >= p
        if keep_mask.all():
            return data
        data.edge_index = data.edge_index[:, keep_mask]
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr[keep_mask]
        return data

    # ------------------------------------------------ train-epoch helper
    def _train_epoch(self, debug: bool = False) -> float:
        """
        Run one pass over the training loader and return the mean batch loss.

        With ``debug=True`` a ``RuntimeError`` is raised as soon as the loss
        or the logits contain NaN/Inf.
        """
        self.model.train()
        total = 0.0
        for step, data in enumerate(self.train_dl):
            data = data.to(self.device)
            data = self._apply_edge_dropout(data)

            data.x = data.x.float()
            if data.edge_attr is not None:
                data.edge_attr = data.edge_attr.float()

            self.optimizer.zero_grad()
            logits = self.model(data).squeeze(-1)
            loss_vec = self.loss_fn(logits, data.y)
            # NOTE(review): the weight scales the loss of *every* sample of a
            # node (positives and negatives alike), although the config field
            # is called node_pos_weights -- confirm this is intended.
            loss_vec = loss_vec * data.node_weight.to(loss_vec.device)
            loss = loss_vec.mean()

            if debug and (has_nan(loss) or has_nan(logits)):
                raise RuntimeError(f"NaN detected at step {step}")

            loss.backward()
            if self.cfg.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()

            total += loss.item()
        return total / len(self.train_dl)

    # ------------------------------------------------ eval helper
    @torch.no_grad()
    def _eval_loader(self, loader: DataLoader) -> Tuple[float, float]:
        """
        Evaluate the model on *loader* and return
        ``(mean weighted loss, mean weighted metric)`` averaged over batches.

        Raises
        ------
        ValueError
            If *loader* yields no batches (previously a ZeroDivisionError).
        """
        n = len(loader)
        if n == 0:
            raise ValueError("Cannot evaluate an empty loader (the split has no snapshots)")

        self.model.eval()
        total_loss = total_metric = 0.0
        for data in loader:
            data = data.to(self.device)
            data.x = data.x.float()
            if data.edge_attr is not None:
                data.edge_attr = data.edge_attr.float()

            logits = self.model(data).squeeze(-1)
            loss_vec = self.loss_fn(logits, data.y)
            metric_vec = self.metric_fn(logits, data.y)

            w = data.node_weight.to(loss_vec.device)
            total_loss += (loss_vec * w).mean().item()
            total_metric += (metric_vec * w).mean().item()
        return total_loss / n, total_metric / n

    # ------------------------------------------------ main training loop
    def train(self, debug: bool = False):
        """
        Train for up to ``cfg.epochs`` epochs with early stopping on the
        validation loss (``cfg.patience`` epochs without improvement).

        The weights of the best epoch are stored as an independent CPU copy
        in ``self.best_ckpt``; they are *not* loaded back automatically --
        call ``self.model.load_state_dict(self.best_ckpt)`` to restore them.

        Raises
        ------
        RuntimeError
            If the model/optimiser/loss or the data loaders are missing.
        """
        if any(v is None for v in (self.model, self.optimizer, self.loss_fn)):
            raise RuntimeError("Call init_model() and compile() before train()")
        if self.train_dl is None or self.val_dl is None:
            raise RuntimeError("Call build_loaders() before train()")

        for epoch in range(1, self.cfg.epochs + 1):
            tr_loss = self._train_epoch(debug)
            val_loss, val_metric = self._eval_loader(self.val_dl)

            # scheduler (ReduceLROnPlateau needs the monitored value)
            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            # LR print
            if self.cfg.print_lr_each_epoch:
                curr_lr = self.optimizer.param_groups[0]["lr"]
                lr_msg = f"LR changed → {curr_lr:.2e}" if curr_lr != self._last_lr else f"LR={curr_lr:.2e}"
                self._last_lr = curr_lr
            else:
                lr_msg = ""

            # store history
            self.history["train_loss"].append(tr_loss)
            self.history["val_loss"].append(val_loss)
            self.history["val_metric"].append(val_metric)

            print(
                f"[Epoch {epoch:03d}/{self.cfg.epochs}] "
                f"train={tr_loss:.4f}  val={val_loss:.4f}  "
                f"metric={val_metric:.4f}  {lr_msg}"
            )

            # early stopping
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                # clone(): on CPU, .cpu() returns the very same tensor, so
                # without a copy the "best" checkpoint would silently track
                # the live weights and be overwritten by later epochs.
                self.best_ckpt = {
                    k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()
                }
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.cfg.patience:
                    print(f"Early stopping at epoch {epoch}")
                    break
        return self

    # ------------------------------------------------ evaluate
    def evaluate(self, split: str = "test") -> Dict[str, float]:
        """
        Return ``{"loss": ..., "metric": ...}`` for the requested split.

        Raises
        ------
        ValueError
            If *split* is unknown, its loader has not been built, or the
            split contains no snapshots.
        """
        loader = {"train": self.train_dl, "val": self.val_dl, "test": self.test_dl}.get(split)
        if loader is None:
            raise ValueError("split must be 'train', 'val' or 'test'")
        loss, metric = self._eval_loader(loader)
        return {"loss": loss, "metric": metric}

    # ------------------------------------------------ predict
    @torch.no_grad()
    def predict(
        self,
        node_frames_new: Dict[str, pd.DataFrame],
        edge_frames_new: Dict[str, pd.DataFrame],
        timestamps: Optional[List[pd.Timestamp]] = None,
        return_df: bool = True,
    ):
        """
        Produce raw logits (not probabilities) for new node/edge frames.

        Parameters
        ----------
        node_frames_new : per-node DataFrames keyed like ``node_frames``;
                          ``cfg.target_col`` is dropped when present and may
                          be absent.  Assumes the remaining columns match the
                          training features in number and order.
        edge_frames_new : per-edge DataFrames keyed like ``edge_frames``
        timestamps      : timestamps to score; defaults to those common to
                          all supplied frames
        return_df       : return a DataFrame (index = timestamps, columns =
                          node names) instead of a ``(T, N)`` tensor

        Raises
        ------
        RuntimeError
            If no model exists or ``prepare_snapshots()`` was never called.
        ValueError
            If there is no timestamp to predict for.
        """
        if self.model is None:
            raise RuntimeError("Train or load a model before calling predict()")
        if not self.reg_order:
            # node/edge ordering is established by prepare_snapshots()
            raise RuntimeError("Call prepare_snapshots() before predict()")

        # decide timestamps
        if timestamps is None:
            timestamps = self._ensure_same_index(
                [df.index for df in node_frames_new.values()]
                + [df.index for df in edge_frames_new.values()]
            )
        else:
            timestamps = [pd.Timestamp(ts) for ts in timestamps]
            if not timestamps:
                raise ValueError("No timestamps supplied for prediction")

        snaps = []
        for ts in timestamps:
            feats = []
            for region in self.reg_order:
                row = node_frames_new[region].loc[ts]
                # The label is usually unknown at inference time, so a
                # missing target column must not be an error.
                feats.append(
                    row.drop(self.cfg.target_col, errors="ignore").to_numpy(dtype=np.float32)
                )
            x = torch.tensor(np.vstack(feats), dtype=torch.float32)

            edge_rows = []
            for s_idx, d_idx in self.edge_order:
                s, d = self.reg_order[s_idx], self.reg_order[d_idx]
                edge_rows.append(
                    edge_frames_new[self._edge_key(s, d)].loc[ts].to_numpy(dtype=np.float32)
                )
            edge_attr = torch.tensor(np.vstack(edge_rows), dtype=torch.float32)
            edge_index = torch.tensor(np.array(self.edge_order).T, dtype=torch.long)

            snaps.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr))

        loader = DataLoader(snaps, batch_size=self.cfg.batch_size, shuffle=False)
        self.model.eval()
        preds = []
        for data in loader:
            data = data.to(self.device)
            data.x = data.x.float()
            if data.edge_attr is not None:
                data.edge_attr = data.edge_attr.float()
            logits = self.model(data).squeeze(-1).cpu()
            preds.append(logits)

        y_hat = torch.cat(preds, dim=0)           # (len(ts)*N)
        y_hat = y_hat.reshape(len(timestamps), len(self.reg_order))

        if return_df:
            return pd.DataFrame(
                y_hat.numpy(),
                index=pd.to_datetime(timestamps),
                columns=self.reg_order,
            )
        return y_hat

    # ------------------------------------------------ plot history
    def plot_history(self, metric: str = "loss"):
        """
        Plot the recorded training curves.

        ``metric="loss"`` shows train and validation loss; ``"metric"``,
        ``"mae"`` or ``"accuracy"`` all plot the stored ``val_metric`` series
        (the name only changes the axis label and title).

        Raises
        ------
        RuntimeError
            If nothing has been recorded yet.
        ValueError
            If *metric* is not one of the accepted names.
        """
        if not self.history["train_loss"]:
            raise RuntimeError("Nothing in history – did you call train()?")

        if metric == "loss":
            plt.figure(figsize=(6, 4))
            plt.plot(self.history["train_loss"], label="Train")
            plt.plot(self.history["val_loss"], label="Val")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.title("Training / Validation Loss")
            plt.legend()
            plt.grid(True)
            plt.show()

        elif metric in {"metric", "mae", "accuracy"}:
            if not self.history["val_metric"]:
                raise ValueError("Metric history empty; choose 'loss' instead.")
            plt.figure(figsize=(6, 4))
            plt.plot(self.history["val_metric"], label=f"Val {metric}")
            plt.xlabel("Epoch")
            plt.ylabel(metric.upper())
            plt.title(f"Validation {metric.upper()} Curve")
            plt.legend()
            plt.grid(True)
            plt.show()
        else:
            raise ValueError("metric must be 'loss' or 'metric'")


### Rare event detection 

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
from statsmodels.nonparametric.smoothers_lowess import lowess  # pip install statsmodels

def prob_vs_continuous(
    df_proba: pd.DataFrame,
    df_true_cont: pd.DataFrame,
    node: str,
    cutoff: float,
    max_points: int = 5000,
    lowess_frac: float = 0.15,
):
    """
    Plot predicted probability against the true continuous value for one
    node: a scatter of the points, a LOWESS-smoothed trend and a vertical
    line at the binarisation cutoff.

    When the series is longer than ``max_points`` a random sub-sample
    (without replacement) is drawn first; both the scatter and the LOWESS
    fit then use that sub-sample.

    Parameters
    ----------
    df_proba      : DataFrame of predicted probabilities (shape [t, nodes])
    df_true_cont  : DataFrame of true continuous target; rows are paired
                    with ``df_proba`` by position, so both must be aligned
    node          : Which node/column to plot
    cutoff        : Threshold that defines the binary label
    max_points    : Maximum number of points to plot
    lowess_frac   : Span for LOWESS smoothing (between 0 and 1)
    """
    proba = df_proba[node].copy()
    truth = df_true_cont[node].copy()

    # Thin out very long series so the figure stays light
    n_obs = len(proba)
    if n_obs > max_points:
        keep = np.random.choice(n_obs, size=max_points, replace=False)
        proba, truth = proba.iloc[keep], truth.iloc[keep]

    plt.figure(figsize=(6, 4))
    sns.scatterplot(x=truth, y=proba, alpha=0.3, s=15, edgecolor="none")

    # Smoothed trend; columns of the result are (sorted x, fitted y)
    trend = lowess(proba, truth, frac=lowess_frac, return_sorted=True)
    trend_x, trend_y = trend[:, 0], trend[:, 1]
    plt.plot(trend_x, trend_y, color="red", lw=2, label="LOWESS avg")

    # Mark the threshold used to binarise the target
    plt.axvline(cutoff, color="gray", ls="--", label=f"cutoff = {cutoff}")

    plt.xlabel("True continuous value")
    plt.ylabel("Predicted probability")
    plt.title(f"{node}: p̂ vs. true value")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Suppose df_logits and df_true_continuous share index & columns
# Example usage.
# NOTE(review): df_logits and df_true_continuous are not defined in this cell;
# presumably df_logits is the logits DataFrame returned by
# GNNExperiment.predict() -- confirm before running.
# The expression below is the element-wise logistic sigmoid.
df_proba = 1 / (1 + np.exp(-df_logits))       # convert to probability
prob_vs_continuous(
    df_proba,
    df_true_continuous,
    node="Tokyo",
    cutoff=0.10,
)

### Comparing to baseline 

In [None]:
# logistic_threshold_noscale.py
import pandas as pd
import numpy as np
from typing import List, Tuple, Dict

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
import matplotlib.pyplot as plt


# ────────────────────────  helpers ────────────────────────
def build_lagged_features(df: pd.DataFrame,
                          target_col: str,
                          lags: List[int]) -> pd.DataFrame:
    """
    Return a copy of *df* extended with one ``{target_col}_lag{k}`` column
    per ``k`` in *lags*, each holding ``target_col`` shifted down by ``k``
    rows.  The first ``k`` rows of a lag column are NaN; *df* itself is
    left untouched.
    """
    return df.assign(
        **{f"{target_col}_lag{k}": df[target_col].shift(k) for k in lags}
    )


def plot_pr_curve(y_true: np.ndarray,
                  y_proba: np.ndarray) -> None:
    """
    Draw the precision-recall curve as a step plot and show the
    average-precision (AP) score in the figure title.

    Parameters
    ----------
    y_true  : true binary labels
    y_proba : predicted scores / probabilities for the positive class
    """
    precision, recall, _thresholds = precision_recall_curve(y_true, y_proba)
    avg_precision = average_precision_score(y_true, y_proba)

    plt.figure()
    plt.step(recall, precision, where="post")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision–Recall curve  (AP = {avg_precision:.3f})")
    # A little headroom above 1.0 keeps the top of the curve visible
    plt.ylim(0, 1.05)
    plt.xlim(0, 1.0)
    plt.grid(alpha=.3)
    plt.show()


# ────────────────────────  main training routine ────────────────────────
def train_logistic_model(df: pd.DataFrame,
                         target_col: str = "value",
                         threshold: float = 10.0,
                         lags: List[int] = (1, 2, 3),
                         test_size: float = 0.2,
                         class_weight: Dict[int, float] | str | None = None,
                         solver: str = "lbfgs",
                         max_iter: int = 1000,
                         random_state: int | None = 42
                         ) -> Tuple[LogisticRegression, dict]:
    """
    Train an unscaled LogisticRegression baseline for the rare event
    ``target_col < threshold``.

    Features are every column of *df* except ``target_col`` plus the lagged
    copies of the target; the last ``test_size`` fraction of rows (in their
    existing order) is held out for evaluation.

    • df must have a DateTime-like index already sorted ascending.
    • class_weight can be a dict such as {0: 1, 1: 200} or "balanced".

    Returns (fitted_model, metrics_dict) where metrics_dict holds
    ``average_precision``, the ``precision`` / ``recall`` / ``thresholds``
    arrays of the PR curve, and the held-out ``y_test`` / ``y_proba``.
    """
    # Lag features, then the binary event label (1 = below threshold)
    frame = build_lagged_features(df, target_col, lags)
    frame["y"] = (frame[target_col] < threshold).astype(int)

    # Shifting leaves NaNs in the leading rows; drop every incomplete row
    frame = frame.dropna()

    # Chronological split: earlier rows train, later rows test (no leakage)
    n_train = int(len(frame) * (1 - test_size))
    train_part = frame.iloc[:n_train]
    test_part = frame.iloc[n_train:]

    non_features = [target_col, "y"]
    X_train, y_train = train_part.drop(columns=non_features), train_part["y"]
    X_test, y_test = test_part.drop(columns=non_features), test_part["y"]

    clf = LogisticRegression(
        class_weight=class_weight,
        solver=solver,
        max_iter=max_iter,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)

    # Precision-recall oriented evaluation on the positive-class probability
    y_proba = clf.predict_proba(X_test)[:, 1]
    ap = average_precision_score(y_test, y_proba)
    prec, rec, thr = precision_recall_curve(y_test, y_proba)

    return clf, {
        "average_precision": ap,
        "precision": prec,
        "recall": rec,
        "thresholds": thr,
        "y_test": y_test.to_numpy(),
        "y_proba": y_proba,
    }
### How you might call it
# Example imbalance: positives ≈ 0.5 %, weight them 200× higher
# (class 1 is the "value below threshold" event, see train_logistic_model)
cw = {0: 1, 1: 200}

# NOTE(review): `df` is not defined in this cell -- it must already exist in
# the notebook session as a DateTime-indexed frame with a "value" column.
model, metrics = train_logistic_model(
    df,                          # your DateTime-indexed DataFrame
    target_col="value",
    threshold=10,
    lags=[1, 2, 3, 6, 12],       # whatever lags make sense
    test_size=0.25,
    class_weight=cw
)

### Quantile regression 

In [None]:
# quantile_regression_gbr.py
import pandas as pd
import numpy as np
from typing import List, Tuple

from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_pinball_loss
import matplotlib.pyplot as plt


# ────────────────────────  helpers ────────────────────────
def build_lagged_features(df: pd.DataFrame,
                          target_col: str,
                          lags: List[int]) -> pd.DataFrame:
    """
    Append lagged copies of ``target_col`` to a copy of *df*.

    For every ``k`` in *lags* a column named ``{target_col}_lag{k}`` is
    added containing the target shifted down by ``k`` rows (leading rows
    become NaN).  The input frame is not modified.
    """
    lag_columns = {
        f"{target_col}_lag{k}": df[target_col].shift(k)
        for k in lags
    }
    return df.assign(**lag_columns)


def plot_actual_vs_pred(y_true: np.ndarray,
                        y_pred: np.ndarray,
                        alpha: float) -> None:
    """
    Quick scatter plot to eyeball calibration of a predicted α-quantile.

    Actual values go on the x-axis and predictions on the y-axis, with the
    dashed y = x diagonal drawn across the range of ``y_true``. For a
    lower-tail quantile the prediction should be smaller than the actual
    value most of the time, so points should lie mostly *below* the
    diagonal; points above it are the cases where actual < predicted.

    Parameters
    ----------
    y_true : np.ndarray
        Observed values; must support ``.min()`` / ``.max()``.
    y_pred : np.ndarray
        Predicted quantile for each entry of ``y_true``.
    alpha : float
        Quantile level; only used in the axis label and title.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=10, alpha=0.3)
    diag_min, diag_max = y_true.min(), y_true.max()
    plt.plot([diag_min, diag_max], [diag_min, diag_max],
             linestyle='--', linewidth=1)
    plt.xlabel("Actual")
    # int(alpha * 100) truncates (0.29 -> "Q28", 0.005 -> "Q0"); the general
    # float format keeps the usual levels unchanged ("Q5", "Q1") and fixes those.
    plt.ylabel(f"Predicted Q{alpha * 100:g}")
    plt.title(f"Actual vs Predicted α-Quantile (α={alpha})")
    plt.grid(alpha=.3)
    plt.show()


# ────────────────────────  main training routine ────────────────────────
def train_quantile_model(df: pd.DataFrame,
                         target_col: str = "value",
                         alpha: float = 0.05,
                         lags: List[int] = (1, 2, 3),
                         test_size: float = 0.2,
                         n_estimators: int = 500,
                         learning_rate: float = 0.03,
                         max_depth: int = 3,
                         random_state: int | None = 42
                         ) -> Tuple[GradientBoostingRegressor, dict]:
    """
    Fits a GradientBoostingRegressor with quantile loss on lagged features.

    Parameters
    ----------
    df : DataFrame
        Must have an index sorted ascending (oldest → newest). Every column
        other than ``target_col`` is used as a feature, together with the
        lag columns added here.
    target_col : str
        Column to predict; also the column the lags are built from.
    alpha : float
        Quantile level passed to the regressor and to the pinball loss.
    lags : sequence of int
        Shifts handed to ``build_lagged_features``.
    test_size : float
        Fraction of the last rows (after dropping NaNs) kept for testing.
    n_estimators, learning_rate, max_depth, random_state
        Passed straight to ``GradientBoostingRegressor``.

    Returns
    -------
    (fitted_model, metrics_dict) where metrics_dict has the keys
    ``pinball_loss``, ``coverage`` (share of test rows with
    actual < predicted), ``y_test`` and ``y_pred``.

    Raises
    ------
    ValueError
        If ``df.index`` is not monotonically increasing.
    """

    # The split below is purely positional, so an unsorted frame would put
    # future rows into the training part without any visible symptom.
    if not df.index.is_monotonic_increasing:
        raise ValueError("df must be sorted in chronological order (oldest → newest).")

    # 1 . Feature engineering – add lags
    data = build_lagged_features(df, target_col, lags)

    # 2 . Drop rows made NaN by shifting
    data = data.dropna()

    # 3 . Time-based split (no leakage!)
    split_idx = int(len(data) * (1 - test_size))
    train, test = data.iloc[:split_idx], data.iloc[split_idx:]

    X_train = train.drop(columns=[target_col])
    y_train = train[target_col]
    X_test = test.drop(columns=[target_col])
    y_test = test[target_col]

    # 4 . Fit quantile GBRT
    gbr = GradientBoostingRegressor(
        loss="quantile",
        alpha=alpha,
        n_estimators=n_estimators,
        learning_rate=learning_rate,
        max_depth=max_depth,
        random_state=random_state
    )
    gbr.fit(X_train, y_train)

    # 5 . Evaluation
    y_pred = gbr.predict(X_test)
    pinball = mean_pinball_loss(y_test, y_pred, alpha=alpha)
    coverage = (y_test < y_pred).mean()     # empirical P(actual < predicted)

    metrics = {
        "pinball_loss": pinball,
        "coverage": coverage,
        "y_test": y_test.to_numpy(),
        "y_pred": y_pred,
    }

    return gbr, metrics


In [None]:
from quantile_regression_gbr import train_quantile_model, plot_actual_vs_pred

# 'df' is expected to be your DateTime-indexed DataFrame with column 'value'.
ALPHA = 0.01              # 1 % lower-tail quantile, shared by fit and plot

model, metrics = train_quantile_model(
    df, target_col="value", alpha=ALPHA, lags=[1, 2, 3, 6, 12], test_size=0.25,
)

pinball_rounded = round(metrics["pinball_loss"], 6)
print("Pinball Loss:", pinball_rounded)
coverage_rounded = round(metrics["coverage"], 4)  # should be ≈ 0.01
print("Empirical coverage:", coverage_rounded)

plot_actual_vs_pred(metrics["y_test"], metrics["y_pred"], alpha=ALPHA)

### Quantile regression (linear, statsmodels QuantReg)

In [None]:
# quantile_linear_regression.py
import pandas as pd
import numpy as np
from typing import List, Tuple

import statsmodels.api as sm
from sklearn.metrics import mean_pinball_loss
import matplotlib.pyplot as plt


# ────────────────────────  helpers ────────────────────────
def build_lagged_features(df: pd.DataFrame,
                          target_col: str,
                          lags: List[int]) -> pd.DataFrame:
    """
    Copy ``df`` and append, for every ``k`` in ``lags``, a column named
    ``{target_col}_lag{k}`` holding ``target_col`` shifted by ``k`` rows.
    """
    lagged = df.copy()
    source = lagged[target_col]
    for lag in lags:
        lagged[f"{target_col}_lag{lag}"] = source.shift(lag)
    return lagged


def plot_actual_vs_pred(y_true: np.ndarray,
                        y_pred: np.ndarray,
                        alpha: float) -> None:
    """
    Scatter plot to eyeball calibration of the predicted α-quantile.

    Actual values are on the x-axis, predictions on the y-axis, and the
    dashed line is the y = x diagonal spanning the range of ``y_true``.

    Parameters
    ----------
    y_true : np.ndarray
        Observed values; must support ``.min()`` / ``.max()``.
    y_pred : np.ndarray
        Predicted quantile for each entry of ``y_true``.
    alpha : float
        Quantile level; only used in the axis label and title.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=14, alpha=0.3)
    lims = [y_true.min(), y_true.max()]
    plt.plot(lims, lims, "--", linewidth=1)
    plt.xlabel("Actual")
    # int(alpha * 100) truncates (0.29 -> "Q28", 0.005 -> "Q0"); the general
    # float format keeps the usual levels unchanged ("Q5", "Q1") and fixes those.
    plt.ylabel(f"Predicted Q{alpha * 100:g}")
    plt.title(f"Linear Quantile Regression (α={alpha})")
    plt.grid(alpha=.3)
    plt.show()


# ─────────────────────  main training routine  ───────────────────────────
def train_quantile_linear_model(df: pd.DataFrame,
                                target_col: str = "value",
                                alpha: float = 0.05,
                                lags: List[int] = (1, 2, 3),
                                test_size: float = 0.2,
                                max_iter: int = 5000,
                                p_tol: float = 1e-5,
                                random_state: int | None = 42
                                ) -> Tuple[sm.regression.linear_model.RegressionResultsWrapper, dict]:
    """
    Fits a *linear* quantile regression (statsmodels.QuantReg) on lagged features.

    Parameters
    ----------
    df : DataFrame
        Must have an index sorted ascending (oldest → newest). Every column
        other than ``target_col`` is used as a regressor, together with the
        lag columns added here and a constant term.
    target_col : str
        Column to predict; also the column the lags are built from.
    alpha : float
        Quantile level, passed as ``q`` to ``QuantReg.fit`` and to the
        pinball loss.
    lags : sequence of int
        Shifts handed to ``build_lagged_features``.
    test_size : float
        Fraction of the last rows (after dropping NaNs) kept for testing.
    max_iter, p_tol
        Passed straight to ``QuantReg.fit``.
    random_state : int | None
        Accepted but never read by this function; kept so existing callers
        that pass it keep working.

    Returns
    -------
    (fitted_results, metrics_dict) where metrics_dict has the keys
    ``pinball_loss``, ``coverage`` (share of test rows with
    actual < predicted), ``y_test`` and ``y_pred``.

    Raises
    ------
    ValueError
        If ``df.index`` is not monotonically increasing.
    """

    # The split below is purely positional, so an unsorted frame would put
    # future rows into the training part without any visible symptom.
    if not df.index.is_monotonic_increasing:
        raise ValueError("df must be sorted in chronological order (oldest → newest).")

    # 1 . Feature engineering – add lags
    data = build_lagged_features(df, target_col, lags)

    # 2 . Drop NaNs caused by shifting
    data = data.dropna()

    # 3 . Time-based split
    split_idx = int(len(data) * (1 - test_size))
    train, test = data.iloc[:split_idx], data.iloc[split_idx:]

    X_train = train.drop(columns=[target_col])
    y_train = train[target_col]
    X_test = test.drop(columns=[target_col])
    y_test = test[target_col]

    # 4 . Add constant term (has_constant="add" is applied to both parts so
    #     train and test end up with the same set of columns)
    X_train_c = sm.add_constant(X_train, has_constant="add")
    X_test_c = sm.add_constant(X_test, has_constant="add")

    # 5 . Fit linear quantile regression
    qr_mod = sm.QuantReg(y_train, X_train_c)
    qr_res = qr_mod.fit(q=alpha,
                        max_iter=max_iter,
                        p_tol=p_tol,
                        disp=False)

    # 6 . Predict and evaluate
    y_pred = qr_res.predict(X_test_c)
    pinball = mean_pinball_loss(y_test, y_pred, alpha=alpha)
    coverage = (y_test < y_pred).mean()     # empirical P(actual < predicted)

    metrics = {
        "pinball_loss": pinball,
        "coverage": coverage,
        "y_test": y_test.to_numpy(),
        "y_pred": y_pred.to_numpy(),
    }

    return qr_res, metrics

In [None]:
from quantile_linear_regression import (
    train_quantile_linear_model,
    plot_actual_vs_pred,
)

# `df`: DateTime-indexed DataFrame that contains the column 'value'.
ALPHA = 0.01             # 1 % lower-tail quantile, shared by fit and plot

results, metrics = train_quantile_linear_model(
    df, target_col="value", alpha=ALPHA, lags=[1, 2, 3, 6, 12], test_size=0.25,
)

# Full regression table first, then the two hold-out metrics.
print(results.summary())
print("Pinball loss:", metrics["pinball_loss"])
print("Empirical coverage:", metrics["coverage"])  # ≈ 0.01 if well-calibrated

plot_actual_vs_pred(metrics["y_test"], metrics["y_pred"], alpha=ALPHA)

### All rewritten (functions take pre-built X / y, no feature engineering)

In [None]:
import pandas as pd
import numpy as np
from typing import Tuple, Dict, Union

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
import matplotlib.pyplot as plt


# ───────────────────────── helpers ──────────────────────────
def plot_pr_curve(y_true: np.ndarray, y_proba: np.ndarray) -> None:
    """
    Draw the precision–recall curve for ``y_proba`` against ``y_true`` as a
    step plot, with the average precision shown in the title.
    """
    precision, recall, _thresholds = precision_recall_curve(y_true, y_proba)
    avg_precision = average_precision_score(y_true, y_proba)

    plt.figure()
    plt.step(recall, precision, where="post")

    # Labelling and axis limits.
    plt.title(f"Precision–Recall curve  (AP = {avg_precision:.3f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.xlim(0, 1.0)
    plt.ylim(0, 1.05)
    plt.grid(alpha=.3)

    plt.show()


# ───────────────────── main training routine ─────────────────────
def train_logistic_model(
    X: pd.DataFrame,
    y: pd.Series,
    test_size: float = 0.2,
    class_weight: Union[Dict[int, float], str, None] = None,
    solver: str = "lbfgs",
    max_iter: int = 1000,
    random_state: int | None = 42,
) -> Tuple[LogisticRegression, dict]:
    """
    Fit a LogisticRegression on the oldest rows and score it on the newest.

    No scaling and no feature engineering is applied; ``X`` is used as given.

    Parameters
    ----------
    X : DataFrame
        Feature matrix whose index is sorted ascending.
    y : Series
        0/1 labels, row-for-row in the same order as ``X``.
    test_size : float
        Share of the last rows held out for evaluation (positional split).
    class_weight : dict | "balanced" | None
        Forwarded unchanged to LogisticRegression.
    solver, max_iter, random_state
        Forwarded unchanged to LogisticRegression.

    Returns
    -------
    (fitted_classifier, metrics) where metrics holds ``average_precision``,
    the ``precision`` / ``recall`` / ``thresholds`` arrays of the PR curve,
    and the hold-out ``y_test`` and ``y_proba`` arrays.

    Raises
    ------
    ValueError
        If ``X.index`` is not monotonically increasing.
    """

    if not X.index.is_monotonic_increasing:
        raise ValueError("X must be sorted in chronological order (oldest → newest).")

    # Positional cut: everything before `n_train` trains, the rest tests.
    n_train = int(len(X) * (1 - test_size))
    head, tail = slice(None, n_train), slice(n_train, None)
    X_train, y_train = X.iloc[head], y.iloc[head]
    X_test, y_test = X.iloc[tail], y.iloc[tail]

    clf = LogisticRegression(
        class_weight=class_weight,
        solver=solver,
        max_iter=max_iter,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)

    # Probability of the positive class (second column) on the hold-out rows.
    positive_proba = clf.predict_proba(X_test)[:, 1]
    avg_precision = average_precision_score(y_test, positive_proba)
    precision, recall, thresholds = precision_recall_curve(y_test, positive_proba)

    return clf, {
        "average_precision": avg_precision,
        "precision": precision,
        "recall": recall,
        "thresholds": thresholds,
        "y_test": y_test.to_numpy(),
        "y_proba": positive_proba,
    }


In [None]:
import pandas as pd
import numpy as np
from typing import Tuple

from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_pinball_loss
import matplotlib.pyplot as plt


def plot_actual_vs_pred(y_true: np.ndarray, y_pred: np.ndarray, alpha: float) -> None:
    """
    Scatter the predicted α-quantile (y-axis) against the actual values
    (x-axis) and overlay the dashed y = x diagonal across the range of
    ``y_true``. ``alpha`` is only used for the axis label and the title.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=10, alpha=0.3)
    lims = [y_true.min(), y_true.max()]
    plt.plot(lims, lims, "--", linewidth=1)
    plt.xlabel("Actual")
    # int(alpha * 100) truncates (0.29 -> "Q28", 0.005 -> "Q0"); the general
    # float format keeps the usual levels unchanged ("Q5", "Q1") and fixes those.
    plt.ylabel(f"Predicted Q{alpha * 100:g}")
    plt.title(f"GBR Quantile α={alpha}")
    plt.grid(alpha=.3)
    plt.show()


def train_quantile_gbr(
    X: pd.DataFrame,
    y: pd.Series,
    alpha: float = 0.05,
    test_size: float = 0.2,
    n_estimators: int = 500,
    learning_rate: float = 0.03,
    max_depth: int = 3,
    random_state: int | None = 42,
) -> Tuple[GradientBoostingRegressor, dict]:
    """
    Quantile regression with a GradientBoostingRegressor on ready-made features.

    The first ``1 - test_size`` share of rows (by position) is used for
    fitting and the remaining rows for evaluation; ``X`` is used as given.
    ``alpha`` is the quantile level; the remaining keyword arguments are
    forwarded to the regressor.

    Returns (fitted_model, metrics) where metrics holds ``pinball_loss``,
    ``coverage`` (share of test rows with actual < predicted), ``y_test``
    and ``y_pred``.

    Raises ValueError if ``X.index`` is not monotonically increasing.
    """

    if not X.index.is_monotonic_increasing:
        raise ValueError("X must be in chronological order.")

    # Positional cut between the training head and the test tail.
    cut = int(len(X) * (1 - test_size))
    X_train, y_train = X.iloc[:cut], y.iloc[:cut]
    X_test, y_test = X.iloc[cut:], y.iloc[cut:]

    gbr = GradientBoostingRegressor(
        loss="quantile",
        alpha=alpha,
        n_estimators=n_estimators,
        learning_rate=learning_rate,
        max_depth=max_depth,
        random_state=random_state,
    )
    gbr.fit(X_train, y_train)

    y_pred = gbr.predict(X_test)

    return gbr, {
        "pinball_loss": mean_pinball_loss(y_test, y_pred, alpha=alpha),
        "coverage": (y_test < y_pred).mean(),
        "y_test": y_test.to_numpy(),
        "y_pred": y_pred,
    }


In [None]:
import pandas as pd
import numpy as np
from typing import Tuple

import statsmodels.api as sm
from sklearn.metrics import mean_pinball_loss
import matplotlib.pyplot as plt


def plot_actual_vs_pred(y_true: np.ndarray, y_pred: np.ndarray, alpha: float) -> None:
    """
    Scatter the predicted α-quantile (y-axis) against the actual values
    (x-axis) and overlay the dashed y = x diagonal across the range of
    ``y_true``. ``alpha`` is only used for the axis label and the title.
    """
    plt.figure()
    plt.scatter(y_true, y_pred, s=14, alpha=0.3)
    lims = [y_true.min(), y_true.max()]
    plt.plot(lims, lims, "--", linewidth=1)
    plt.xlabel("Actual")
    # int(alpha * 100) truncates (0.29 -> "Q28", 0.005 -> "Q0"); the general
    # float format keeps the usual levels unchanged ("Q5", "Q1") and fixes those.
    plt.ylabel(f"Predicted Q{alpha * 100:g}")
    plt.title(f"Linear Quantile α={alpha}")
    plt.grid(alpha=.3)
    plt.show()


def train_quantile_linear(
    X: pd.DataFrame,
    y: pd.Series,
    alpha: float = 0.05,
    test_size: float = 0.2,
    max_iter: int = 5000,
    p_tol: float = 1e-5,
) -> Tuple[sm.regression.linear_model.RegressionResultsWrapper, dict]:
    """
    Linear quantile regression (statsmodels QuantReg) on ready-made features.

    The first ``1 - test_size`` share of rows (by position) is used for
    fitting and the remaining rows for evaluation. A constant column is
    added to both parts before fitting. ``alpha`` is passed as ``q`` to
    ``QuantReg.fit`` together with ``max_iter`` and ``p_tol``.

    Returns (fitted_results, metrics) where metrics holds ``pinball_loss``,
    ``coverage`` (share of test rows with actual < predicted), ``y_test``
    and ``y_pred``.

    Raises ValueError if ``X.index`` is not monotonically increasing.
    """

    if not X.index.is_monotonic_increasing:
        raise ValueError("X must be in chronological order.")

    # Positional cut between the training head and the test tail.
    cut = int(len(X) * (1 - test_size))
    X_train, y_train = X.iloc[:cut], y.iloc[:cut]
    X_test, y_test = X.iloc[cut:], y.iloc[cut:]

    design_train = sm.add_constant(X_train, has_constant="add")
    design_test = sm.add_constant(X_test, has_constant="add")

    model = sm.QuantReg(y_train, design_train)
    qr_res = model.fit(q=alpha, max_iter=max_iter, p_tol=p_tol, disp=False)

    y_pred = qr_res.predict(design_test)

    return qr_res, {
        "pinball_loss": mean_pinball_loss(y_test, y_pred, alpha=alpha),
        "coverage": (y_test < y_pred).mean(),
        "y_test": y_test.to_numpy(),
        "y_pred": y_pred.to_numpy(),
    }
