In [None]:
from typing import Tuple

import cf_xarray as cfxr
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import polars as pl
import torch
import xarray as xr

from model import load_model_checkpoint
from prepare_data import FEATURES

In [None]:
# Feature names: everything in FEATURES followed by the "local_power_pred" entry.
features = list(FEATURES) + ["local_power_pred"]
features

In [None]:
# Checkpoint file loaded by the following cell through load_model_checkpoint().
checkpoint_path = "checkpoints/wind_masked_last.pth"
# Passed as the first argument of get_data() further down; that function does
# not read it (it opens a hard-coded zarr store instead).
data_path = "data/torch_dataset_all_zones.pt"
# NOTE(review): var_idx, n_points and plot_scaled are not referenced by any of
# the visible cells -- confirm whether they are still needed.
var_idx = 10
n_points = 200
plot_scaled = False

In [None]:
# Load the model together with its checkpoint object; cells below call
# checkpoint.get("x_mean") / checkpoint.get("x_std") to normalize the inputs.
model, checkpoint = load_model_checkpoint(checkpoint_path)

In [None]:
# Open the stored dataset and decode its compressed "forecast_index" back into
# a MultiIndex with cf_xarray, then display the result.
encoded = xr.open_dataset("data/dataset_all_zones.zarr")
ds = cfxr.decode_compress_to_multi_index(encoded, "forecast_index")
ds

## Predictions per bidding area

In [None]:
def get_data(data_path, val_cutoff_date, features, bidding_area, device):
    """Load the all-zones dataset as tensors and split sample indices by date.

    Args:
        data_path: Accepted for interface compatibility but not read; the
            dataset is always opened from "data/dataset_all_zones.zarr".
        val_cutoff_date: Value passed to ``np.datetime64``. Samples whose
            ``time_ref`` is before it go to the train indices, samples at or
            after it go to the validation indices.
        features: Feature names to select from ``ds["X"]``. When ``None``, all
            entries along the last axis are loaded, the final one is returned
            as the mask and the rest as ``X``.
        bidding_area: Value for ``ds.sel(bidding_area=...)``; ``None`` skips
            the selection.
        device: Torch device the returned tensors are moved to.

    Returns:
        Tuple ``(ds, X, y, mask, train_idx, val_idx)`` where ``ds`` is the
        filtered dataset, ``X``/``y``/``mask`` are float32 tensors and the two
        index tensors hold sample positions in their original order.
    """

    def _to_float_tensor(data_array):
        # xarray DataArray -> float32 torch tensor (still on the CPU).
        return torch.from_numpy(data_array.values.astype(np.float32))

    ds = cfxr.decode_compress_to_multi_index(
        xr.open_dataset("data/dataset_all_zones.zarr"), "forecast_index"
    )

    # Keep only samples that have a target value.
    has_target = ~np.isnan(ds["y"])
    ds = ds.isel(forecast_index=has_target)
    if bidding_area is not None:
        ds = ds.sel(bidding_area=bidding_area)

    if features is None:
        # No explicit selection: last entry of the final axis is the mask.
        everything = _to_float_tensor(ds["X"])
        mask = everything[..., -1].to(device)
        X = everything[..., :-1].to(device)
    else:
        X = _to_float_tensor(ds["X"].sel(feature=features)).to(device)
        mask = _to_float_tensor(ds["X"].sel(feature="mask")).to(device)

    y = _to_float_tensor(ds["y"]).to(device)

    # X is unpacked as three axes; only the number of samples is needed here.
    n_samples, _, _ = X.shape

    # Chronological split: no shuffling for time-series evaluation.
    sample_idx = torch.arange(n_samples, device=device)
    cutoff = np.datetime64(val_cutoff_date)
    ref_times = ds.time_ref.values
    train_idx = sample_idx[ref_times < cutoff]
    val_idx = sample_idx[ref_times >= cutoff]
    return ds, X, y, mask, train_idx, val_idx


# Evaluate the dedicated checkpoint of each bidding area and collect one frame
# per area. Area-specific names (ds_area, area_model, area_checkpoint, ...) are
# used on purpose: reusing `ds`, `model` and `checkpoint` here would silently
# replace the all-zones dataset and model loaded at the top of the notebook,
# which the RMSE cells further down read.
eval_dfs = []
for zone_number, bidding_area in enumerate(
    ["ELSPOT NO1", "ELSPOT NO2", "ELSPOT NO3", "ELSPOT NO4"], start=1
):
    print(bidding_area)
    # The train/val index tensors are not needed; the split is redone by date
    # in the summary cell below.
    ds_area, X_area, y_area, mask_area, _, _ = get_data(
        data_path, "2025-01-01", None, bidding_area, device="cpu"
    )
    # Baseline: last feature entry summed over the preceding axis.
    local_preds_area = X_area[..., -1].sum(dim=-1)

    area_model, area_checkpoint = load_model_checkpoint(
        f"checkpoints/wind_NO{zone_number}_last.pth"
    )
    x_mean_area = area_checkpoint.get("x_mean", None)
    x_std_area = area_checkpoint.get("x_std", None)
    if x_mean_area is None or x_std_area is None:
        # Fail with a clear message instead of a TypeError from `X - None`.
        raise ValueError(
            f"Checkpoint for {bidding_area} has no 'x_mean'/'x_std' normalization stats"
        )
    X_norm_area = (X_area - x_mean_area) / x_std_area
    with torch.no_grad():
        preds_area = area_model(X_norm_area, mask_area)

    eval_dfs.append(
        pl.DataFrame(
            {
                "time_ref": ds_area.time_ref.values,
                "time": ds_area.time.values,
                "lt": ds_area.lt.values,
                "y_true": y_area,
                "y_pred": preds_area,
                "local_preds": local_preds_area,
                "bidding_area": bidding_area,
            }
        )
    )
df_eval = pl.concat(eval_dfs)
df_eval

In [None]:
# RMSE of the model and of the summed local predictions, per train/val subset
# and bidding area. Rows with time_ref before 2025-01-01 count as "train".
subset_label = (
    pl.when(pl.col("time_ref") < np.datetime64("2025-01-01"))
    .then(pl.lit("train"))
    .otherwise(pl.lit("val"))
)
model_sq_err = (pl.col("y_pred") - pl.col("y_true")) ** 2
local_sq_err = (pl.col("local_preds") - pl.col("y_true")) ** 2
(
    df_eval.with_columns(subset=subset_label)
    .group_by("subset", "bidding_area")
    .agg(
        RMSE=model_sq_err.mean().sqrt(),
        RMSE_local_pred=local_sq_err.mean().sqrt(),
    )
    .sort("subset", "bidding_area")
)

In [None]:
# Time series of y_true / y_pred / local_preds for one bidding area,
# restricted to rows with lt < 24.
bidding_area = "ELSPOT NO3"
area_short_lead = df_eval.filter(
    pl.col("lt") < 24, pl.col("bidding_area") == bidding_area
)
area_long = area_short_lead.unpivot(index=["time_ref", "time", "lt", "bidding_area"])
px.line(area_long, x="time", y="value", color="variable")

## RMSE

In [None]:
# Evaluate on the last 20% of samples along forecast_index (positional split).
N = ds["y"].shape[0]
val_idx = int(0.8 * N)
ds_newest = ds.isel(forecast_index=slice(val_idx, None))

# The final entry of the last axis is used as the mask, the rest as inputs.
X_with_mask = torch.from_numpy(ds_newest["X"].values)
mask = X_with_mask[..., -1]
X_newest = X_with_mask[..., :-1]
y_newest = torch.from_numpy(ds_newest["y"].values.astype(np.float32))
bidding_area = ds_newest.bidding_area.values

# Normalize with the statistics stored in the checkpoint.
x_mean, x_std = checkpoint.get("x_mean", None), checkpoint.get("x_std", None)
X_norm = (X_newest - x_mean) / x_std

with torch.no_grad():
    preds = model(X_norm, mask)

# Baseline: "local_power_pred" summed over the station dimension.
local_preds = (
    ds_newest["X"].sel(feature="local_power_pred").sum(dim="station").values
)

holdout_frame = pl.DataFrame(
    dict(
        y_true=y_newest,
        y_pred=preds,
        local_preds=local_preds,
        bidding_area=bidding_area,
    )
)
model_sq_err = (pl.col("y_pred") - pl.col("y_true")) ** 2
local_sq_err = (pl.col("local_preds") - pl.col("y_true")) ** 2
(
    holdout_frame.filter(pl.col("y_true").is_not_nan())
    .group_by("bidding_area")
    .agg(
        RMSE=model_sq_err.mean().sqrt(),
        RMSE_local_pred=local_sq_err.mean().sqrt(),
    )
    .sort("bidding_area")
)

In [None]:
# Select the forecasts with the latest reference time in the dataset.
# (A fixed timestamp, e.g. np.datetime64("2024-10-21T09:00:00"), can be
# assigned to ds_newest_time_ref instead.)
ds_newest_time_ref = ds.time_ref.max()
ds_newest = ds.sel(time_ref=ds_newest_time_ref)

# The final entry of the last axis is used as the mask, the rest as inputs.
X_with_mask = torch.from_numpy(ds_newest["X"].values)
mask = X_with_mask[..., -1]
X_newest = X_with_mask[..., :-1]
y_newest = torch.from_numpy(ds_newest["y"].values.astype(np.float32))

time = ds_newest.time.values
bidding_area = ds_newest.bidding_area.values

# Normalize with the statistics stored in the checkpoint.
x_mean, x_std = checkpoint.get("x_mean", None), checkpoint.get("x_std", None)
X_norm = (X_newest - x_mean) / x_std

In [None]:
# Run the model on the prepared inputs without tracking gradients.
with torch.no_grad():
    preds = model(X_norm, mask)

# Baseline: "local_power_pred" summed over the station dimension.
local_preds = (
    ds_newest["X"].sel(feature="local_power_pred").sum(dim="station").values
)

df_eval = pl.DataFrame(
    dict(
        y_true=y_newest,
        y_pred=preds,
        local_preds=local_preds,
        time=time,
        bidding_area=bidding_area,
    )
)

# One colour per bidding area, one dash style per series.
eval_long = df_eval.unpivot(index=["time", "bidding_area"])
px.line(
    eval_long,
    x="time",
    y="value",
    color="bidding_area",
    line_dash="variable",
)

# Disabled alternative: totals summed over all bidding areas.
# px.line(
#     df_eval.group_by("time").agg(pl.col("y_true").sum(), pl.col("y_pred").sum(), pl.col("local_preds").sum()).unpivot(index=["time"]).sort("time"),
#     "time",
#     "value",
#     color="variable",
# )

In [None]:
# RMSE per bidding area for the model and for the summed local predictions.
model_sq_err = (pl.col("y_pred") - pl.col("y_true")) ** 2
local_sq_err = (pl.col("local_preds") - pl.col("y_true")) ** 2
(
    df_eval.group_by("bidding_area")
    .agg(
        RMSE=model_sq_err.mean().sqrt(),
        RMSE_local_pred=local_sq_err.mean().sqrt(),
    )
    .sort("bidding_area")
)

In [None]:
# RMSE over all rows of df_eval, for the model and the summed local predictions.
model_sq_err = (pl.col("y_pred") - pl.col("y_true")) ** 2
local_sq_err = (pl.col("local_preds") - pl.col("y_true")) ** 2
df_eval.select(
    RMSE=model_sq_err.mean().sqrt(),
    RMSE_local_pred=local_sq_err.mean().sqrt(),
)

In [None]:
# Last 20% of samples along forecast_index, restricted to one bidding area
# and to lt values up to 24.
bidding_area = "ELSPOT NO4"
N = ds["y"].shape[0]
val_idx = int(0.8 * N)

ds_holdout = ds.isel(forecast_index=slice(val_idx, None))
ds_newest = ds_holdout.sel(bidding_area=bidding_area, lt=slice(None, 24))
# To use every sample of the area instead:
# ds_newest = ds.sel(bidding_area=bidding_area)

# The final entry of the last axis is used as the mask, the rest as inputs.
X_with_mask = torch.from_numpy(ds_newest["X"].values)
mask = X_with_mask[..., -1]
X_newest = X_with_mask[..., :-1]
y_newest = torch.from_numpy(ds_newest["y"].values.astype(np.float32))

time = ds_newest.time.values
bidding_area = ds_newest.bidding_area.values  # rebinds the name selected above
time_ref = ds_newest.time_ref.values

# Normalize with the statistics stored in the checkpoint.
x_mean, x_std = checkpoint.get("x_mean", None), checkpoint.get("x_std", None)
X_norm = (X_newest - x_mean) / x_std

In [None]:
# Run the model on the prepared inputs without tracking gradients.
with torch.no_grad():
    preds = model(X_norm, mask)

# Lead time in whole hours between the reference time and the target time.
lead_time_hours = (pl.col("time") - pl.col("time_ref")).dt.total_hours()

# "local_preds" is left out here; the plotting cell below adds it.
df_eval = pl.DataFrame(
    dict(
        y_true=y_newest,
        y_pred=preds,
        time=time,
        time_ref=time_ref,
    )
).with_columns(lt=lead_time_hours)

In [None]:
# Baseline: "local_power_pred" summed over the station dimension.
local_preds = (
    ds_newest["X"].sel(feature="local_power_pred").sum(dim="station").values
)

# Long format with one line per series; lt stays available in the hover box.
plot_frame = df_eval.with_columns(local_preds=local_preds).unpivot(
    index=["time", "time_ref", "lt"]
)
fig = px.line(
    plot_frame,
    x="time",
    y="value",
    color="variable",
    hover_data=["lt"],
)
# fig.update_layout(hovermode="x unified")
fig.show()

In [None]:
# Predictions as points coloured by lead time, with the observed series drawn
# on top. `go` (plotly.graph_objects) comes from the import cell at the top of
# the notebook rather than being imported mid-notebook.

# One y_true value per timestamp: the first one encountered in each group.
df_plot = df_eval.group_by("time").agg(pl.col("y_true").first()).sort("time")

fig = px.scatter(df_eval, "time", "y_pred", color="lt")
fig.add_trace(go.Scatter(x=df_plot["time"], y=df_plot["y_true"], name="y_true"))
fig

In [None]:
# RMSE of the model predictions over all rows of df_eval.
squared_error = (pl.col("y_pred") - pl.col("y_true")) ** 2
df_eval.select(RMSE=squared_error.mean().sqrt())