In [None]:
from typing import Optional, Tuple

import cf_xarray as cfxr
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import polars as pl
import torch
import xarray as xr

from model import SharedPerLocationSum

In [None]:
# Input feature names, used below to label the variable selected by index
# (`features[var_idx]`) along the last axis of X.
# NOTE(review): assumes this order matches the column order used to build the
# dataset / train the model — confirm.
features = (
    ["lt", "operating_power_max", "mean_production", "num_turbines"]
    # bidding-area columns
    + ["ELSPOT NO1", "ELSPOT NO2", "ELSPOT NO3", "ELSPOT NO4"]
    # recent production summaries
    + ["last_day_mean", "last_value"]
    # weather variables, "_00" suffix
    + ["ws10m_00", "wd10m_00", "t2m_00", "rh2m_00", "mslp_00", "g10m_00"]
    # weather variables, "_mean" suffix (no wind direction)
    + ["ws10m_mean", "t2m_mean", "rh2m_mean", "mslp_mean", "g10m_mean"]
    # weather variables, "_std" suffix (no wind direction)
    + ["ws10m_std", "t2m_std", "rh2m_std", "mslp_std", "g10m_std"]
    # "now_" weather variables
    + [
        "now_air_temperature_2m",
        "now_air_pressure_at_sea_level",
        "now_relative_humidity_2m",
        "now_precipitation_amount",
        "now_wind_speed_10m",
        "now_wind_direction_10m",
    ]
    # sin/cos calendar encodings (presumably hour-of-day, day-of-year, day-of-week)
    + ["sin_hod", "cos_hod", "sin_doy", "cos_doy", "sin_dow", "cos_dow"]
)

In [None]:
def main(ckpt_path, data_path, var_idx=2, n_points=200, plot_scaled=False):
    """Plot the learned per-location response phi(x) along one input feature.

    The swept feature takes ``n_points`` values between its 1% and 99%
    empirical quantiles in the dataset. All other inputs are held at the
    training mean stored in the checkpoint (``x_mean``) or, if that is absent,
    at the (NaN-ignoring) mean of ``X`` over its first two axes.

    Args:
        ckpt_path: Checkpoint dict with ``state_dict`` and optionally
            ``model_kwargs`` and normalization statistics
            (``normalize_x``, ``normalize_y``, ``x_mean``, ``x_std``, ``y_std``).
        data_path: Saved dict whose ``"X"`` entry holds the raw inputs,
            indexed as ``X[..., var_idx]``.
        var_idx: Index of the feature to sweep (last axis of ``X``).
        n_points: Number of grid points along the swept feature.
        plot_scaled: If True and the target was normalized, multiply phi by
            ``y_std`` (the global ``y_mean`` offset is still not added).
    """
    device = "cpu"

    # --- Load checkpoint & rebuild model ---
    ckpt = torch.load(ckpt_path, map_location=device)
    # Copy so the loaded checkpoint is not mutated. Checkpoints store the layer
    # sizes under "hidden"; the model constructor expects "width".
    model_kwargs = dict(
        ckpt.get("model_kwargs", {"in_dim": 7, "hidden": (64, 32), "dropout": 0.0})
    )
    if "hidden" in model_kwargs:
        model_kwargs["width"] = model_kwargs.pop("hidden")
    V = model_kwargs["in_dim"]
    model = SharedPerLocationSum(**model_kwargs).to(device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()

    # Normalization stats (may be None if you trained without normalization)
    normalize_x = bool(ckpt.get("normalize_x", True))
    normalize_y = bool(ckpt.get("normalize_y", True))
    x_mean = ckpt.get("x_mean", None)
    x_std = ckpt.get("x_std", None)
    y_std = ckpt.get("y_std", None)
    if x_mean is not None:
        x_mean = x_mean.to(device)  # expected shape (1, V)
    if x_std is not None:
        x_std = x_std.to(device)
    if y_std is not None:
        y_std = y_std.to(device)

    # --- Empirical distribution of the swept variable (raw units) ---
    blob = torch.load(data_path, map_location=device)
    X = blob["X"].float()  # (N, L, V)
    x_var = X[..., var_idx].reshape(-1).cpu().numpy()  # flattened over (N, L)
    # 1%..99% quantiles avoid extreme tails. np.nanquantile ignores NaNs and,
    # unlike torch.quantile, has no input-size limit.
    qs = np.linspace(0.01, 0.99, n_points)
    x_values = torch.from_numpy(np.nanquantile(x_var, qs).astype(np.float32)).to(device)

    # --- Input batch: one single-location row per grid point ---
    base = torch.zeros(n_points, V, device=device)
    if x_mean is not None:
        base[:] = x_mean  # broadcast (1, V) -> (n_points, V)
    else:
        # fallback: dataset mean over (N, L) for each variable
        base[:] = torch.nanmean(X, dim=(0, 1)).to(device)
    base[:, var_idx] = x_values

    # Apply input normalization if used during training
    if normalize_x and (x_mean is not None) and (x_std is not None):
        base_norm = (base - x_mean) / (x_std + 1e-6)
    else:
        base_norm = base

    # --- Evaluate phi directly on (n_points, V) ---
    with torch.no_grad():
        phi_vals = model.phi(base_norm).squeeze(-1).cpu().numpy()  # (n_points,)

    if plot_scaled and normalize_y and (y_std is not None):
        y_to_plot = phi_vals * y_std.item()
        ylabel = "Per-location contribution φ(x) [approx. target units (× y_std)]"
    else:
        y_to_plot = phi_vals
        ylabel = "Per-location contribution φ(x) [normalized target units]"

    # --- Plot ---
    # Guard against a feature list shorter than the model input.
    if -len(features) <= var_idx < len(features):
        xlabel = features[var_idx]
    else:
        xlabel = f"Variable {var_idx} (raw units)"
    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(x_values.cpu().numpy(), y_to_plot, "-o")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title("Learned per-location φ response")
    fig.tight_layout()
    plt.show()

In [None]:
# Run configuration for the phi sweep: paths to the trained checkpoint and the
# tensor dataset, then the feature to sweep.
ckpt_path = "checkpoints/wind_residual_last.pth"
data_path = "data/torch_dataset_all_zones.pt"
# == 10; assumes `features` matches the column order of X — confirm.
var_idx = features.index("ws10m_00")
n_points = 200
plot_scaled = False

main(
    ckpt_path,
    data_path,
    var_idx=var_idx,
    n_points=n_points,
    plot_scaled=plot_scaled,
)

In [None]:
# Rebuild the trained model at notebook scope for the evaluation cells below.
# map_location="cpu" matches the plotting helpers, so a checkpoint saved on a
# GPU also loads on a CPU-only machine.
ckpt = torch.load(ckpt_path, map_location="cpu")
# Copy (rather than mutate ckpt) and rename "hidden" -> "width", which the
# constructor expects; a missing "model_kwargs" raises a clear KeyError.
model_kwargs = dict(ckpt["model_kwargs"])
if "hidden" in model_kwargs:
    model_kwargs["width"] = model_kwargs.pop("hidden")
model = SharedPerLocationSum(**model_kwargs)
model.load_state_dict(ckpt["state_dict"])
model.eval()

In [None]:
def _binned_median(x, y, num_bins):
    """Return (x medians, y medians) over quantile bins of ``x``.

    Bin edges are ``num_bins + 1`` quantiles of ``x``. Duplicate edges (ties)
    are merged and empty bins are skipped. Returns two empty lists when at
    most one bin remains.
    """
    edges = np.unique(np.quantile(x, np.linspace(0.0, 1.0, num_bins + 1)))
    if len(edges) <= 2:
        return [], []
    bin_ids = np.digitize(x, edges[1:-1], right=True)
    med_x, med_y = [], []
    for b in range(len(edges) - 1):
        in_bin = bin_ids == b
        if in_bin.any():
            med_x.append(float(np.median(x[in_bin])))
            med_y.append(float(np.median(y[in_bin])))
    return med_x, med_y


@torch.no_grad()
def scatter_phi_vs_feature_from_data(
    ckpt_path: str,
    data_path: str,
    var_idx: int = 2,
    color_idx: Optional[int] = None,
    num_pairs: int = 50000,
    plot_scaled: bool = False,
    overlay_median: bool = True,
    num_bins: int = 30,
    seed: int = 0,
    figsize: Tuple[int, int] = (7, 4),
    # missing-data handling
    drop_missing: bool = True,
    missing_idx: int = 1,
    missing_value: float = 0.0,
):
    """Plot phi(x) on real (timestamp, location) samples against one feature.

    Samples up to ``num_pairs`` *distinct* (t, l) pairs from ``X``. Pairs whose
    feature ``missing_idx`` equals ``missing_value`` are optionally dropped.
    The model's per-location ``phi`` is evaluated on the sampled rows and shown
    as an interactive plotly figure:

    * a scatter of phi vs. feature ``var_idx`` (raw units), or
    * box plots per x value, coloured by quartile of feature ``color_idx``,
      when ``color_idx`` is given.

    Args:
        ckpt_path: checkpoint with ``model_kwargs`` and ``state_dict`` (and
            optional normalization statistics).
        data_path: saved dict whose ``"X"`` entry is a (N, L, V) tensor.
        var_idx: feature plotted on the x-axis.
        color_idx: optional feature whose quartiles ("q0".."q3") colour the plot.
        num_pairs: maximum number of distinct pairs to sample.
        plot_scaled: multiply phi by ``y_std`` when the target was normalized
            (``y_mean`` is not added).
        overlay_median: draw a dashed line through binned medians of phi.
        num_bins: number of quantile bins for the median overlay.
        seed: seed for the pair sampling (local generator; the global torch
            RNG is left untouched).
        figsize: unused; kept for backward compatibility (the plotly figure
            is 1200x700 px).
        drop_missing: drop pairs flagged as missing (see below).
        missing_idx: feature index used to flag missing pairs.
        missing_value: value marking a missing pair; NaN is supported.

    Raises:
        KeyError: if the checkpoint has no ``model_kwargs``.
        ValueError: if no valid pairs remain after filtering.
    """
    device = "cpu"
    gen = torch.Generator().manual_seed(seed)

    # ----- load checkpoint/model
    ckpt = torch.load(ckpt_path, map_location=device)
    if ckpt.get("model_kwargs") is None:
        raise KeyError(
            f"Checkpoint {ckpt_path!r} has no 'model_kwargs'; cannot rebuild the model."
        )
    # Copy so the loaded checkpoint is not mutated; constructor expects "width".
    model_kwargs = dict(ckpt["model_kwargs"])
    if "hidden" in model_kwargs:
        model_kwargs["width"] = model_kwargs.pop("hidden")
    V = model_kwargs["in_dim"]
    model = SharedPerLocationSum(**model_kwargs).to(device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()

    # ----- normalization statistics (any of them may be absent)
    normalize_x = bool(ckpt.get("normalize_x", True))
    normalize_y = bool(ckpt.get("normalize_y", True))
    x_mean = ckpt.get("x_mean", None)
    x_std = ckpt.get("x_std", None)
    y_std = ckpt.get("y_std", None)
    if x_mean is not None:
        x_mean = x_mean.to(device)
    if x_std is not None:
        x_std = x_std.to(device)
    if y_std is not None:
        y_std = y_std.to(device)

    # ----- load data
    blob = torch.load(data_path, map_location=device)
    X = blob["X"].float()  # (N, L, V)
    N, L, Vx = X.shape
    assert Vx == V, f"Data V={Vx} != model V={V}"

    # ----- build valid (t, l) index set
    if drop_missing:
        flag = X[..., missing_idx]  # (N, L)
        if np.isnan(missing_value):
            # NaN != NaN, so a plain comparison would keep every NaN entry.
            valid_mask = ~torch.isnan(flag)
        else:
            valid_mask = flag != missing_value
    else:
        valid_mask = torch.ones(N, L, dtype=torch.bool)

    valid_pairs = valid_mask.nonzero(as_tuple=False)  # (M, 2): rows of (t, l)
    M = valid_pairs.shape[0]
    if M == 0:
        raise ValueError(
            "No valid (timestamp, location) pairs after missing-value filter."
        )

    # Sample K distinct pairs uniformly (without replacement); torch.randint
    # would draw duplicates and miss pairs even when K == M.
    K = min(num_pairs, M)
    sel = torch.randperm(M, generator=gen)[:K]
    tl = valid_pairs[sel]
    t_idx, l_idx = tl[:, 0], tl[:, 1]

    # gather raw features (advanced indexing copies, so X is not aliased)
    X_pairs_raw = X[t_idx, l_idx, :]  # (K, V)
    x_np = X_pairs_raw[:, var_idx].cpu().numpy()  # raw units
    color_np = (
        X_pairs_raw[:, color_idx].cpu().numpy() if color_idx is not None else None
    )

    # normalize inputs if used in training
    if normalize_x and (x_mean is not None) and (x_std is not None):
        X_pairs = (X_pairs_raw - x_mean) / (x_std + 1e-6)
    else:
        X_pairs = X_pairs_raw

    # evaluate phi
    phi_vals = model.phi(X_pairs).squeeze(-1)  # (K,)
    scaled = plot_scaled and normalize_y and (y_std is not None)
    if scaled:
        phi_vals = phi_vals * y_std  # approx target units (no +y_mean)
    phi_np = phi_vals.cpu().numpy()

    # ----- plot
    if -len(features) <= var_idx < len(features):
        x_label = features[var_idx]
    else:
        x_label = f"Variable {var_idx} (raw units)"
    y_label = "φ(x) per location" + (" [approx. target units]" if scaled else "")
    plot_kwargs = dict(
        title="φ vs. feature (valid samples only)",
        height=700,
        width=1200,
        labels={"x": x_label, "phi": y_label},
    )
    if color_idx is not None:
        quartile_labels = [f"q{k}" for k in range(4)]
        df_plot = (
            pl.DataFrame({"x": x_np, "phi": phi_np, "color": color_np})
            .with_columns(color=pl.col("color").qcut(4, labels=quartile_labels))
            .sort("color")
        )
        # category_orders pins the legend/box order to q0..q3 independently of
        # how the categorical column happens to sort.
        fig = px.box(
            df_plot,
            "x",
            "phi",
            color="color",
            category_orders={"color": quartile_labels},
            **plot_kwargs,
        )
    else:
        df_plot = pl.DataFrame({"x": x_np, "phi": phi_np})
        fig = px.scatter(df_plot, "x", "phi", **plot_kwargs)

    # optional binned median trend
    if overlay_median:
        med_x, med_y = _binned_median(x_np, phi_np, num_bins)
        if med_x:
            fig.add_scatter(
                x=med_x,
                y=med_y,
                mode="lines",
                line=dict(dash="dash", color="red"),
                name="binned median",
            )
    fig.show()

    if drop_missing:
        print(
            f"Sampled {K} / {M} valid pairs (after filtering {missing_idx} != {missing_value})."
        )
    else:
        print(f"Sampled {K} / {M} pairs (no missing-value filter applied).")


# Box plots of phi against features[-6] (sin_hod in `features`), coloured by
# quartile of feature 1 (operating_power_max). Rows where feature 1 is 0 are
# treated as missing and dropped.
phi_scatter_kwargs = dict(
    ckpt_path=ckpt_path,
    data_path=data_path,
    var_idx=-6,
    color_idx=1,
    num_pairs=50000,
    drop_missing=True,  # enable filtering
    missing_idx=1,  # feature index that indicates missing
    missing_value=0.0,  # treat 0 as missing
    plot_scaled=False,
    overlay_median=True,
    num_bins=30,
    figsize=(10, 7),
)
scatter_phi_vs_feature_from_data(**phi_scatter_kwargs)


In [None]:
# Open the zarr store (engine stated explicitly; it is what xarray infers from
# the ".zarr" suffix) and decode "forecast_index" back into a MultiIndex.
encoded = xr.open_dataset("data/dataset_all_zones.zarr", engine="zarr")
ds = cfxr.decode_compress_to_multi_index(encoded, "forecast_index")
ds

In [None]:
# Inputs and targets for the most recent forecast reference time.
ds_newest_time_ref = ds.time_ref.max()
ds_newest = ds.sel(time_ref=ds_newest_time_ref)
# float32 to match the model weights (y is cast the same way).
X_newest = torch.from_numpy(ds_newest["X"].values.astype(np.float32))
y_newest = torch.from_numpy(ds_newest["y"].values.astype(np.float32))
time = ds_newest.time.values
bidding_area = ds_newest.bidding_area.values

# Normalize exactly like the plotting helpers: honour the checkpoint's
# normalize_x flag, tolerate missing stats, and add the same epsilon so a
# zero-std feature does not produce inf/NaN.
x_mean = ckpt.get("x_mean", None)
x_std = ckpt.get("x_std", None)
if bool(ckpt.get("normalize_x", True)) and x_mean is not None and x_std is not None:
    X_norm = (X_newest - x_mean) / (x_std + 1e-6)
else:
    X_norm = X_newest

In [None]:
# Inputs and targets for the last `n_timesteps` forecasts of one bidding area
# at one lead time. The selection is kept in its own name so it is not
# confused with the `bidding_area` coordinate values assigned below.
selected_area = "ELSPOT NO3"
lt = 1
n_timesteps = 100

ds_newest = ds.sel(bidding_area=selected_area, lt=lt).isel(
    # NOTE(review): "last" assumes forecast_index is ordered by time — confirm.
    forecast_index=slice(-n_timesteps, None)
)
# float32 to match the model weights (y is cast the same way).
X_newest = torch.from_numpy(ds_newest["X"].values.astype(np.float32))
y_newest = torch.from_numpy(ds_newest["y"].values.astype(np.float32))
time = ds_newest.time.values
bidding_area = ds_newest.bidding_area.values

# Normalize exactly like the plotting helpers (normalize_x flag, optional
# stats, same epsilon against zero std).
x_mean = ckpt.get("x_mean", None)
x_std = ckpt.get("x_std", None)
if bool(ckpt.get("normalize_x", True)) and x_mean is not None and x_std is not None:
    X_norm = (X_newest - x_mean) / (x_std + 1e-6)
else:
    X_norm = X_newest

In [None]:
# Evaluation-only forward pass: autograd is switched off inside the block.
with torch.set_grad_enabled(False):
    preds = model(X_norm)

In [None]:
# Long-format comparison of targets and predictions over time, one colour per
# bidding area and one dash style per series (y_true / y_pred).
df_eval = pl.DataFrame(
    {
        "y_true": y_newest.numpy(),
        "y_pred": preds.numpy(),
        "time": time,
        # Broadcast so a single selected area (scalar coordinate) still gives
        # one value per row; a per-row array passes through unchanged.
        "bidding_area": np.broadcast_to(bidding_area, np.shape(time)),
    }
)
px.line(
    df_eval.unpivot(index=["time", "bidding_area"]),
    "time",
    "value",
    color="bidding_area",
    line_dash="variable",
)

In [None]:
# Wide view of the predictions: one column per bidding area. Only areas that
# are actually present are selected; after the single-area cell df_eval holds
# just one area, and an unconditional select would raise.
area_columns = ["ELSPOT NO1", "ELSPOT NO2", "ELSPOT NO3", "ELSPOT NO4"]
df_pred_wide = df_eval.pivot("bidding_area", index="time", values="y_pred")
df_pred_wide.select(
    "time", *[area for area in area_columns if area in df_pred_wide.columns]
)

In [None]:
# Root-mean-squared error of y_pred vs. y_true for each bidding area.
squared_error = (pl.col("y_pred") - pl.col("y_true")).pow(2)
(
    df_eval.group_by("bidding_area")
    .agg(RMSE=squared_error.mean().sqrt())
    .sort("bidding_area")
)

In [None]:
# Overall RMSE of the predictions against the true targets. preds is reshaped
# to y_newest's shape so an extra trailing dim (e.g. (N, 1) vs (N,)) cannot
# silently broadcast into an (N, N) difference.
torch.sqrt(((preds.reshape(y_newest.shape) - y_newest) ** 2).mean()).item()