In [None]:
import joblib
from collections import defaultdict
from datetime import datetime

import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Load the saved dump from the parent directory; the path is relative to the
# kernel's working directory.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only load dumps that come from a trusted source.
dump = joblib.load("../dump.joblib")

In [None]:
# Positional unpack: the first value becomes `batch`, the fourth becomes
# `y_hat`, and the two in between are discarded. This requires `dump` to hold
# exactly four values, in this order.
# NOTE(review): presumably `dump` is a dict whose insertion order is fixed by
# whatever wrote it -- confirm, or look the entries up by key instead.
# NOTE(review): in IPython `_` normally refers to the last cell output;
# binding it here shadows that -- verify nothing relies on it.
batch, _, _, y_hat = dump.values()

In [None]:
# Collapse every leading axis into one so that `y_hat` becomes 2-D; the size of
# the last axis is left unchanged.
n_outputs = y_hat.shape[-1]
y_hat = y_hat.reshape((-1, n_outputs))

# Show the resulting (rows, n_outputs) shape as the cell output.
y_hat.shape

In [None]:
# Gather the "pv_t0_idx" entry of every element of `batch`, take the first one
# as the reference value, and check that all the others agree with it.
pv_t0 = [sample["pv_t0_idx"] for sample in batch]
pv_t0_idx = pv_t0[0]
assert all(idx == pv_t0_idx for idx in pv_t0)

In [None]:
# Turn the sequence of per-element mappings into one mapping of arrays: for
# each key (except "pv_t0_idx", which was extracted separately) the values are
# stacked, reshaped to one row per row of `y_hat`, and stripped of any
# length-1 axes.
# NOTE: this rebinds `batch` from a sequence to a dict, so the cell cannot be
# re-run without first re-running the cell that unpacks `dump`.
batch = {
    name: np.array([sample[name] for sample in batch])
    .reshape(y_hat.shape[0], -1)
    .squeeze()
    for name in batch[0].keys()
    if name != "pv_t0_idx"
}

In [None]:
# One line per entry of `batch`: the key left-aligned in a 25-character column,
# followed by the shape of its array.
for name, arr in batch.items():
    print(f"{name:<25}{arr.shape}")

In [None]:
# Ground truth: the "pv" columns from index `pv_t0_idx` onward, one row per sample.
y = batch["pv"][:, pv_t0_idx:]

# Truth and prediction share one set of step suffixes below, so the two arrays
# must line up exactly; fail early with a readable message if they do not.
assert y.shape == y_hat.shape, f"y {y.shape} and y_hat {y_hat.shape} must have the same shape"
data = np.hstack((y, y_hat))

# Columns are named "true<k>" / "pred<k>" with k = column index * 15.
# NOTE(review): 15 is presumably the step length in minutes -- confirm.
steps = [x * 15 for x in range(y.shape[-1])]
df = pd.DataFrame(
    data,
    columns=[f"true{s}" for s in steps] + [f"pred{s}" for s in steps],
)

# Attach the identifiers of each row. The timestamp at `pv_t0_idx` is read as
# POSIX seconds and converted in one vectorised call, yielding naive UTC
# datetimes (this replaces a per-row `datetime.utcfromtimestamp`, which is
# deprecated since Python 3.12).
df = df.assign(
    system=batch["pv_system_row_number"],
    datetime=pd.to_datetime(batch["pv_time_utc"][:, pv_t0_idx], unit="s"),
)

# Order chronologically, then keep only the first row of every
# (system, datetime) pair.
df = df.sort_values(by=["datetime", "system"])
df = df.drop_duplicates(subset=["system", "datetime"])

print(len(df))
df.head()

In [None]:
# Reshape the wide table into long format: every "true<k>" / "pred<k>" column
# pair becomes one row keyed by (system, datetime, step), where `step` is the
# numeric suffix k. The absolute and squared prediction errors are then added
# as extra columns.
lo = (
    pd.wide_to_long(
        df,
        stubnames=["true", "pred"],
        i=["system", "datetime"],
        j="step",
    )
    .reset_index()
    .assign(
        abs_err=lambda frame: frame["pred"].sub(frame["true"]).abs(),
        squ_err=lambda frame: frame["pred"].sub(frame["true"]) ** 2,
    )
)
lo.head()

In [None]:
# Debugging aid (disabled): show every step for system 2 at the datetime of the first row of `lo`.
# lo.loc[(lo.system == 2) & (lo.datetime == lo.iloc[0].datetime)]

In [None]:
# Absolute error against step, with one line per value of `system`. Where
# several rows share the same (step, system), seaborn aggregates them itself.
ax = sns.lineplot(
    data=lo,
    x="step",
    y="abs_err",
    hue="system",
)
# Label the figure so it can be read on its own; the trailing `;` suppresses
# the repr of the value returned by `set`.
ax.set(
    xlabel="step",
    ylabel="absolute error |pred - true|",
    title="Absolute error by step, per system",
);