In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded before
# each cell runs, so edits to the project's source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime as dt
import os
import pathlib
import importlib

import shap
import pandas as pd
import altair as alt
import numpy as np
import pvlib
from psp.serialization import load_model
from psp.data.data_sources.pv import NetcdfPvDataSource
from psp.data.data_sources.nwp import NwpDataSource
from psp.typings import X
from psp.visualization import plot_sample
from psp.metrics import mean_absolute_error

import plotly.express as px

# Turn off Altair's maximum-rows check so that large DataFrames can be charted.
alt.data_transformers.disable_max_rows()


def _(df, *args, **kwargs):
    """Print how many rows `df` has, then display its first rows.

    Any extra positional or keyword arguments are forwarded to `df.head`
    (for instance the number of rows to show).
    """
    n_rows = len(df)
    print(n_rows)
    preview = df.head(*args, **kwargs)
    display(preview)

In [None]:
# The working directory is supplied through the `CWD` environment variable (defined in the
# Makefile). When the variable is missing or empty, the current directory is left untouched.
CWD = os.getenv("CWD")
if CWD:
    os.chdir(CWD)

In [None]:
# Names of the experiments to analyse. Each name is used to build paths of the form
# `exp_results/<name>/...` (model pickle, error CSVs, config module).
EXPS = [
    "island-v3",
]

In [None]:
# Path to the serialized model of each experiment, in the same order as `EXPS`.
# The name is `MODELS` (upper case) because that is what the model-loading cell reads;
# the previous lower-case `models` left `MODELS` undefined and caused a NameError there.
MODELS = [f"exp_results/{exp}/model.pkl" for exp in EXPS]

In [None]:
def load_testset(exp, split="test"):
    """Read the per-sample error table of one experiment.

    The table is loaded from `exp_results/<exp>/<split>_errors.csv`, with `pv_id` kept
    as a string and `ts` parsed as a datetime. Two columns are appended:

    * `pred_ts`: `ts` shifted by `horizon` minutes.
    * `exp`: the experiment name, repeated on every row.

    Arguments:
        exp: Name of the experiment (directory under `exp_results`).
        split: Prefix of the errors file to read, "test" by default.

    Returns:
        The resulting `pandas.DataFrame`.
    """
    errors = pd.read_csv(
        f"exp_results/{exp}/{split}_errors.csv",
        dtype={"pv_id": str},
        parse_dates=["ts"],
    )
    return errors.assign(
        pred_ts=lambda df: df["ts"] + pd.to_timedelta(df["horizon"], unit="minute"),
        exp=exp,
    )


# Stack the test-split error tables of all experiments into one frame (rows carry their
# experiment name in the "exp" column), then print its length and preview it.
testset = pd.concat(list(map(load_testset, EXPS)))
_(testset)

In [None]:
# When several experiments are listed, the ground-truth PV data is taken from the
# configuration of the first one only.
first_exp_config = importlib.import_module(f"exp_results.{EXPS[0]}.config").ExpConfig()
pv_ds = first_exp_config.get_pv_data_source()

In [None]:
# Load one model per experiment, keyed by experiment name (`MODELS` holds the paths, in the
# same order as `EXPS`).
# NOTE(review): the files are named `model.pkl`, so `load_model` presumably unpickles them —
# only load files from a trusted source.
models = dict(zip(EXPS, map(load_model, MODELS)))

# Wire each model to the data sources described by its own experiment config.
for name, model in models.items():
    exp_config = importlib.import_module(f"exp_results.{name}.config").ExpConfig()
    model.set_data_sources(**exp_config.get_data_source_kwargs())

In [None]:
# Order the rows so that the largest errors come first.
testset = testset.sort_values(by="error", ascending=False)

In [None]:
# Print the number of rows and show the first 20 of them.
_(testset, n=20)

In [None]:
from psp.visualization import find_horizon_index

# Position of the sample to inspect, within the filtered test set (negative values count
# from the end).
SAMPLE_IDX = -200
# Filters applied to the test set before picking the sample.
MAX_HORIZON_MINUTES = 60 * 12  # keep horizons strictly below 12 hours
MIN_Y = 0.1  # keep rows whose `y` is strictly above this value
EXCLUDED_PV_ID = "7759"  # PV id left out of the selection

shap.initjs()

candidates = testset[
    (testset["horizon"] < MAX_HORIZON_MINUTES)
    & (testset["y"] > MIN_Y)
    & (testset["pv_id"] != EXCLUDED_PV_ID)
]
# Fail with an explicit message instead of the bare "out-of-bounds" error from `.iloc`.
if not -len(candidates) <= SAMPLE_IDX < len(candidates):
    raise IndexError(
        f"SAMPLE_IDX={SAMPLE_IDX} is out of range: only {len(candidates)} rows match the filters"
    )
test_row = candidates.iloc[SAMPLE_IDX]

print("test row")
print(test_row)
print()

row = dict(test_row)
pv_id = row["pv_id"]
ts = row["ts"]
horizon = row["horizon"]

# To inspect a different sample, override `pv_id`, `ts` or `horizon` here, e.g.:
# pv_id = "27000"
# ts = dt.datetime(2021, 7, 31, 12, 30)
# horizon = 60 * 7

# The horizon index is resolved against the horizons of the last loaded model. This used to
# go through a `model` variable leaked from a loop in another cell; the lookup is now
# explicit so the cell does not depend on that hidden state.
reference_model = list(models.values())[-1]
horizon_idx = find_horizon_index(horizon, reference_model.config.horizons)

# NOTE(review): `nwp_ds` was not defined anywhere in the notebook, so the call below raised
# a NameError. It is now read from the data-source kwargs of the first experiment config.
# The "nwp_data_source" key is an assumption — TODO confirm against the config's
# `get_data_source_kwargs`. When the key is absent, `nwp_ds` is None.
nwp_ds = first_exp_config.get_data_source_kwargs().get("nwp_data_source")

plot_sample(
    x=X(pv_id=pv_id, ts=ts),
    horizon_idx=horizon_idx,
    models=models,
    pv_data_source=pv_ds,
    nwp_data_source=nwp_ds,
    #     meta=meta,
)

In [None]:
# Number of test rows to explain, and the seed that makes their selection reproducible
# across kernel restarts.
N_EXPLAINED_SAMPLES = 20
SAMPLE_SEED = 0

# Never request more rows than the test set holds: `DataFrame.sample` raises in that case.
sampled_rows = testset.sample(
    n=min(N_EXPLAINED_SAMPLES, len(testset)), random_state=SAMPLE_SEED
)
if sampled_rows.empty:
    raise ValueError("testset is empty: there is nothing to explain")

xs = [
    X(pv_id=sample_pv_id, ts=sample_ts)
    for sample_pv_id, sample_ts in zip(sampled_rows["pv_id"], sampled_rows["ts"])
]

for model_name, model in models.items():
    print(model_name)
    explanations = [model.explain(x) for x in xs]
    # Average the per-sample explanations; this relies on the objects returned by
    # `model.explain` supporting `+` (including with the integer 0 that `sum` starts from)
    # and division by a number.
    explanation = sum(explanations) / len(explanations)

    # `shap.plots.bar` renders the figure itself; wrapping it in `display` only added a
    # spurious `None` to the output.
    shap.plots.bar(explanation)