In [None]:
# Switch the working directory so the relative "data/..." paths used below resolve.
# NOTE(review): this is a user-specific location; adjust it to your own checkout.
%cd ~/projects/wind/

In [None]:
import polars as pl
import numpy as np
import pymc as pm
import plotly.express as px
from datetime import datetime, timedelta
import pytensor.tensor as pt
import plotly.graph_objs as go

In [None]:
# Bidding area to model and location of the area-level wind power dataset.
bidding_area = "ELSPOT NO4"
dataset_path = "data/windpower_area_dataset.parquet"

# Forecasts with a reference time at or after this instant are held out for validation.
val_cutoff = datetime(2025, 1, 1)

# Row filters shared by both splits: the selected bidding area, and target
# times that fall on the calendar day following the forecast reference time.
_ref_date = pl.col("time_ref").dt.date()
_shared_filters = (
    pl.col("bidding_area") == bidding_area,
    pl.col("time") >= _ref_date + timedelta(days=1),
    pl.col("time") < _ref_date + timedelta(days=2),
)

# Training split: reference times from 2024-01-01 up to (excluding) the cutoff.
data_train = pl.scan_parquet(dataset_path).filter(
    *_shared_filters,
    pl.col("time_ref") >= datetime(2024, 1, 1),
    pl.col("time_ref") < val_cutoff,
)

# Validation split: reference times from the cutoff onwards.
data_val = pl.scan_parquet(dataset_path).filter(
    *_shared_filters,
    pl.col("time_ref") >= val_cutoff,
)

In [None]:
# Lead times ("lt") are divided by this value before entering the scale model.
# NOTE(review): presumably the largest lead time in the dataset — confirm.
max_lt = 48


def get_emos_features(data):
    """Build the EMOS design matrices, target and weights from a forecast frame.

    Only the columns that are actually used are materialised, and they are
    materialised exactly once. Every output is then derived from that single
    in-memory frame, so the source is not re-read for each output and all
    outputs are guaranteed to share the same row order.

    Parameters
    ----------
    data : pl.LazyFrame | pl.DataFrame
        Frame holding the ensemble summaries (``*_sum_pred``), the recent
        production statistics, the cyclic time encodings, ``lt``,
        ``relative_power`` and ``operating_power_max``.

    Returns
    -------
    X_mu : np.ndarray
        float32 matrix with 13 columns (intercept first) for the location model.
    X_sigma : np.ndarray
        float32 matrix with 10 columns (intercept first) for the log-scale model.
    y : np.ndarray
        1-D array taken from the ``relative_power`` column.
    sample_weight : np.ndarray
        1-D array taken from the ``operating_power_max`` column.
    """
    needed_columns = (
        "mean_sum_pred",
        "min_sum_pred",
        "max_sum_pred",
        "std_sum_pred",
        "last_power",
        "recent_mean",
        "recent_std",
        "recent_max",
        "recent_min",
        "ramp",
        "sin_hod",
        "cos_hod",
        "sin_doy",
        "cos_doy",
        "lt",
        "relative_power",
        "operating_power_max",
    )
    # `.lazy()` is a no-op on a LazyFrame and lets an eager DataFrame be passed too.
    frame = data.lazy().select(*needed_columns).collect()

    X_mu = (
        frame.select(
            pl.lit(1).alias("intercept"),
            "mean_sum_pred",
            "min_sum_pred",
            "max_sum_pred",
            "last_power",
            "recent_mean",
            "ramp",
            "recent_max",
            "recent_min",
            "sin_hod",
            "cos_hod",
            "sin_doy",
            "cos_doy",
        )
        .cast(pl.Float32)
        .to_numpy()
    )

    # Spread-type predictors enter on the log scale.
    # NOTE(review): a zero in any logged input (e.g. ramp == 0 or max == min)
    # yields -inf here — confirm the data excludes such rows.
    X_sigma = (
        frame.select(
            pl.lit(1).alias("intercept"),
            pl.col("std_sum_pred").log().alias("log_std_sum_pred"),
            (pl.col("max_sum_pred") - pl.col("min_sum_pred"))
            .log()
            .alias("log_max_sum_pred"),
            pl.col("recent_std").log().alias("log_recent_std"),
            pl.col("ramp").abs().log().alias("log_ramp"),
            "sin_hod",
            "cos_hod",
            "sin_doy",
            "cos_doy",
            (pl.col("lt") / max_lt).alias("lt"),
        )
        .cast(pl.Float32)
        .to_numpy()
    )

    y = frame.select("relative_power").to_numpy()[:, 0]
    sample_weight = frame.select("operating_power_max").to_numpy()[:, 0]
    return X_mu, X_sigma, y, sample_weight


# Materialise design matrices, targets and weights for the training and validation splits.
X_mu_train, X_sigma_train, y_train, sample_weight_train = get_emos_features(data_train)
X_mu_val, X_sigma_val, y_val, sample_weight_val = get_emos_features(data_val)

In [None]:
def fit_emos_pymc(
    X_mu,
    X_sigma,
    y,
    draws=1500,
    tune=1500,
    target_accept=0.9,
    chains=4,
    seed=123,
):
    """Fit a LogitNormal EMOS regression with NUTS (nutpie, JAX backend).

    The location of the LogitNormal is ``X_mu @ beta`` and its log-scale is
    ``X_sigma @ gamma``; both coefficient vectors get zero-centred Normal priors.
    The inputs are registered as ``pm.Data`` containers ("Xmu", "Xsig",
    "y_obs") so they can be replaced later with ``pm.set_data``.

    Parameters
    ----------
    X_mu : array, shape (n_obs, n_mu_features)
        Design matrix for the location (logit-scale mean).
    X_sigma : array, shape (n_obs, n_sigma_features)
        Design matrix for the log of the scale.
    y : array, shape (n_obs,)
        Observed response.
        NOTE(review): LogitNormal has open support (0, 1); values exactly at
        0 or 1 lie outside it — confirm y avoids them.
    draws, tune, chains : int
        Retained draws, tuning steps and number of chains passed to ``pm.sample``.
    target_accept : float
        Target acceptance rate for step-size adaptation, forwarded to ``pm.sample``.
    seed : int
        Random seed forwarded to ``pm.sample``.

    Returns
    -------
    (model, idata)
        The PyMC model and the object returned by ``pm.sample``.
    """
    with pm.Model() as model:
        Xmu = pm.Data("Xmu", X_mu)
        Xsig = pm.Data("Xsig", X_sigma)
        y_obs = pm.Data("y_obs", y)

        # Priors: one coefficient per design-matrix column.
        beta = pm.Normal("beta", mu=0.0, sigma=2.0, shape=X_mu.shape[1])
        gamma = pm.Normal("gamma", mu=0.0, sigma=0.2, shape=X_sigma.shape[1])

        # Linear predictors
        mu = pm.Deterministic("mu", pt.dot(Xmu, beta))  # logit-mean
        log_sig = pm.Deterministic("log_sigma", pt.dot(Xsig, gamma))
        sigma = pm.Deterministic("sigma", pm.math.exp(log_sig))  # > 0

        # Likelihood; the return value is not needed because the variable is
        # registered on the model under the name "y".
        pm.LogitNormal("y", mu=mu, sigma=sigma, observed=y_obs)

        idata = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            # Previously commented out, so the argument was silently ignored.
            target_accept=target_accept,
            random_seed=seed,
            progressbar=True,
            nuts_sampler="nutpie",
            nuts_sampler_kwargs=dict(backend="jax"),
        )
    return model, idata

In [None]:
# Fit on the training split: 4 chains with 1000 tuning steps and 1000 retained
# draws each (overriding the function's defaults of 1500 for both).
model, idata = fit_emos_pymc(
    X_mu_train,
    X_sigma_train,
    y_train,
    draws=1000,
    tune=1000,
    chains=4,
)

In [None]:
# Render the model's variable graph as the cell output.
pm.model_to_graphviz(model)

In [None]:
import arviz as az
import matplotlib.pyplot as plt

# Trace plot of the location coefficients; compact=False requests separate
# panels instead of overlaying all coefficients in one.
ax = az.plot_trace(idata, var_names="beta", compact=False)
plt.tight_layout()

In [None]:
# Draw 10 samples from the prior predictive distribution.
# NOTE: this uses whatever data is currently set on `model`; run it before any
# cell that swaps the model's data via pm.set_data.
with model:
    prior_pred = pm.sample_prior_predictive(draws=10)

In [None]:
# Prior predictive draws of "y" with chain/draw stacked into one trailing axis;
# the 10 columns correspond to the 10 prior draws sampled above.
prior_y_wide = pl.DataFrame(
    prior_pred.prior_predictive["y"].stack(sample=("chain", "draw")).values,
    schema=[str(k) for k in range(10)],
)

# Long format (one row per simulated value) so all draws share one histogram.
prior_y_long = prior_y_wide.unpivot()
px.histogram(prior_y_long, x="value")

In [None]:
# One column per prior draw (10 draws), one row per training observation.
prior_paths_wide = pl.DataFrame(
    prior_pred.prior_predictive["y"].stack(sample=("chain", "draw")).values,
    schema=[str(k) for k in range(10)],
)

# Time keys of the training rows, attached column-wise to the simulated values.
# This relies on the model still holding the training data when prior_pred was drawn.
train_time_keys = data_train.select("time", "time_ref", "lt").collect()

prior_paths_long = (
    pl.concat([prior_paths_wide, train_time_keys], how="horizontal")
    .unpivot(index=["time", "time_ref", "lt"])
    .filter(pl.col("time") < pl.col("time_ref").dt.date() + timedelta(days=2))
)

# One line per prior draw over target time.
px.line(prior_paths_long, x="time", y="value", color="variable")

In [None]:
def predict_emos_quantiles(
    model,
    idata,
    X_mu_new,
    X_sigma_new,
    q=(0.025, 0.05, 0.5, 0.95, 0.975),
    random_seed=None,
):
    """Sample the posterior predictive for new inputs and summarise it.

    Note that this replaces the data held by ``model`` ("Xmu", "Xsig", "y_obs")
    with the new inputs; the model keeps them after the call.

    Parameters
    ----------
    model : pm.Model
        Model exposing the data containers "Xmu", "Xsig" and "y_obs".
    idata
        Posterior draws passed to ``pm.sample_posterior_predictive``.
    X_mu_new, X_sigma_new : array
        Design matrices for the new observations (same number of rows).
    q : sequence of float
        Quantile levels in [0, 1]. Each becomes a column named ``q`` followed
        by the level times 1000, zero-padded to three digits (0.025 -> "q025").
    random_seed : optional
        Forwarded to ``pm.sample_posterior_predictive`` for reproducible draws.

    Returns
    -------
    (out, posterior)
        ``out`` is a polars DataFrame with one row per observation holding the
        quantile columns plus ``pred_mean`` and ``pred_std``; ``posterior`` is
        the predictive "y" variable with chain and draw stacked into "sample".
    """
    with model:
        pm.set_data(
            {
                "Xmu": X_mu_new,
                "Xsig": X_sigma_new,
                # Placeholder so the observed container matches the new row count.
                "y_obs": np.zeros(X_mu_new.shape[0], dtype=np.float32),
            }
        )
        ppc = pm.sample_posterior_predictive(idata, random_seed=random_seed)

    posterior = ppc.posterior_predictive["y"].stack(sample=("chain", "draw"))
    # Put the sample axis first: (n_chains * n_draws, n_obs).
    samples = posterior.transpose("sample", "y_dim_0").values

    # Quantiles per observation -> (n_obs, len(q)).
    qs = np.quantile(samples, q, axis=0).T
    # round() rather than int(): int() truncates and would mislabel a level
    # whose floating-point product lands just below the intended integer.
    labels = [f"q{int(round(1000 * float(qq))):03d}" for qq in q]
    out = pl.DataFrame(qs, schema=labels).with_columns(
        pred_mean=samples.mean(axis=0),
        pred_std=samples.std(axis=0, ddof=1),
    )
    return out, posterior

In [None]:
# Predictive summary and stacked predictive samples for the validation split.
# This also leaves `model` holding the validation inputs.
out, posterior = predict_emos_quantiles(model, idata, X_mu_val, X_sigma_val)

In [None]:
# Display the per-observation quantiles, predictive mean and standard deviation.
out

In [None]:
# Load stored quantile predictions and keep a single bidding area for plotting.
# NOTE(review): this selects "ELSPOT NO3" from a CSV, whereas the model in this
# notebook is fitted for `bidding_area` ("ELSPOT NO4") — confirm this is intended.
stored_quantile_preds = pl.read_csv("data/quantile_pred.csv")
df_plot = stored_quantile_preds.filter(pl.col("bidding_area") == "ELSPOT NO3")
df_plot

In [None]:
# Alternative source for df_plot, built from this notebook's own predictions:
# df_plot = pl.concat([data_val.collect(), out], how="horizontal").filter(
#     pl.col("time") >= pl.col("time_ref").dt.date() + timedelta(days=1),
#     pl.col("time") < pl.col("time_ref").dt.date() + timedelta(days=2),
#     # pl.col("time_ref") == datetime(2024, 4, 1, 9)
# )


def _series_trace(name, column, color):
    """Visible line trace of one df_plot column over time."""
    return go.Scatter(
        name=name,
        x=df_plot["time"],
        y=df_plot[column],
        mode="lines",
        line=dict(color=color),
    )


def _band_edge(name, column, fillcolor=None):
    """Invisible edge of a prediction band.

    When `fillcolor` is given, the area between this trace and the trace added
    immediately before it is shaded ("tonexty").
    """
    trace_kwargs = dict(
        name=name,
        x=df_plot["time"],
        y=df_plot[column],
        mode="lines",
        marker=dict(color="#444"),
        line=dict(width=0),
        showlegend=False,
    )
    if fillcolor is not None:
        trace_kwargs["fillcolor"] = fillcolor
        trace_kwargs["fill"] = "tonexty"
    return go.Scatter(**trace_kwargs)


# Trace order matters: each lower bound fills up to the upper bound just before it.
fig = go.Figure(
    [
        _series_trace("y_true", "relative_power", "rgb(237, 55, 31)"),
        _series_trace("y_pred", "pred_mean", "rgb(31, 119, 180)"),
        _band_edge("Upper Bound alpha-10", "q950"),
        _band_edge("Lower Bound alpha-10", "q050", fillcolor="rgba(68, 68, 68, 0.3)"),
        _band_edge("Upper Bound alpha-5", "q975"),
        _band_edge("Lower Bound alpha-5", "q025", fillcolor="rgba(68, 68, 68, 0.1)"),
    ]
)
# [fig.add_vline(x=x) for x in df_plot["time_ref"].unique()]
fig.update_layout(
    yaxis=dict(title=dict(text="Power")),
    title=dict(text="Continuous, variable value error bars"),
    hovermode="x",
)
fig.show()

In [None]:
# Share of observations below, inside and above the [q025, q975] band.
# (Swap in q050 / q950 to check the narrower band instead.)
_below_band = pl.col("relative_power") < pl.col("q025")
_above_band = pl.col("relative_power") > pl.col("q975")

df_plot.select(
    _below_band.mean().alias("under"),
    (~_below_band & ~_above_band).mean().alias("in"),
    _above_band.mean().alias("over"),
)

In [None]:
# Distribution of the training target. `px` comes from the notebook's import
# cell, and only the two columns the figure uses are materialised.
px.histogram(
    data_train.select("relative_power", "bidding_area").collect(),
    x="relative_power",
    color="bidding_area",
    barmode="overlay",
)

In [None]:
def per_observation_crps(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Sample-based CRPS for each observation.

    Computes ``mean|x_i - y| - sum_{i<j}|x_i - x_j| / S**2`` over the ``S``
    samples on the leading axis of ``y_pred``; the pairwise term is obtained
    from the sorted samples instead of an explicit double loop.

    Parameters
    ----------
    y_true : array-like
        Observed values.
    y_pred : array-like
        Samples with one extra leading (sample) axis; the remaining axes must
        equal ``y_true``'s shape, optionally preceded by singleton axes.

    Returns
    -------
    np.ndarray
        CRPS values with the sample axis reduced away.

    Raises
    ------
    ValueError
        If the trailing shape of ``y_pred`` does not match ``y_true``.
    """
    observations = np.array(y_true)
    ensemble = np.array(y_pred)

    n_singleton_axes = ensemble.ndim - observations.ndim - 1
    expected_trailing_shape = (1,) * n_singleton_axes + observations.shape
    if ensemble.shape[1:] != expected_trailing_shape:
        raise ValueError(
            f"""Expected y_pred to have one extra sample dim on left.
                Actual shapes: {ensemble.shape} versus {observations.shape}"""
        )

    # First term: mean absolute error of the samples against the observation.
    mean_abs_error = np.abs(ensemble - observations).mean(axis=0)

    n_members = ensemble.shape[0]
    if n_members == 1:
        # A single sample has no spread, so the CRPS is just the absolute error.
        return mean_abs_error

    # Second term: sum of pairwise absolute differences. After sorting, the gap
    # between ranks k and k+1 is crossed by k * (S - k) sample pairs.
    ordered = np.sort(ensemble, axis=0)
    gaps = ordered[1:] - ordered[:-1]
    ranks = np.arange(1, n_members)
    pairs_per_gap = ranks * (n_members - ranks)
    # Append singleton axes so the counts broadcast against `gaps`.
    pairs_per_gap = pairs_per_gap[(slice(None),) + (None,) * (gaps.ndim - 1)]

    spread_term = (gaps * pairs_per_gap).sum(axis=0) / n_members**2
    return mean_abs_error - spread_term


def crps(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    sample_weight: np.ndarray | None = None,
):
    """Average the per-observation CRPS, optionally weighted.

    Parameters
    ----------
    y_true : np.ndarray
        Observed values.
    y_pred : np.ndarray
        Samples with the sample axis first.
    sample_weight : np.ndarray | None
        Weights passed to ``np.average``; ``None`` gives the plain mean.
    """
    scores = per_observation_crps(y_true, y_pred)
    return np.average(scores, weights=sample_weight)

In [None]:
from sklearn.metrics import root_mean_squared_error

# Scale relative power and the predictive samples by "operating_power_max"
# so both scores are computed on the scaled values.
capacity = data_val.select("operating_power_max").collect().to_series().to_numpy()
y_true = y_val * capacity
y_pred = posterior.transpose("sample", "y_dim_0").values * capacity

# CRPS scores the full set of samples.
crps_score = crps(y_true, y_pred)
print(f"CRPS: {crps_score}")

# RMSE scores the mean over samples as a point forecast.
rmse_score = root_mean_squared_error(y_true, y_pred.mean(axis=0))
print(f"RMSE: {rmse_score}")