In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.display import display, HTML

# Inject a CSS rule so the `.container` element uses the full browser width.
full_width_css = HTML("<style>.container { width:100% !important; }</style>")
display(full_width_css)

In [None]:
import datetime as dt
import os
import pathlib
import importlib

import shap
import pandas as pd
import altair as alt
import numpy as np
import pvlib
from psp.serialization import load_model
from psp.data_sources.pv import NetcdfPvDataSource
from psp.data_sources.nwp import NwpDataSource
from psp.typings import X
from psp.visualization import plot_sample
from psp.metrics import mean_absolute_error
from psp.models.multi import MultiPvSiteModel

import plotly.express as px

# Lift Altair's row limit so that large DataFrames can be passed to charts.
alt.data_transformers.disable_max_rows()


def _(df, *args, **kwargs):
    """Print the number of rows of `df`, then display its first rows.

    Extra positional and keyword arguments are forwarded to `df.head`.

    NOTE(review): the name `_` is also what IPython uses for the last cell
    output; it is kept because other cells call this helper by that name.
    """
    n_rows = len(df)
    print(n_rows)

    preview = df.head(*args, **kwargs)
    display(preview)

In [None]:
# The working directory is taken from the `CWD` environment variable (defined in the Makefile)
# so that it does not have to be set by hand. When the variable is unset or empty we stay put.
CWD = os.getenv("CWD")
if CWD:
    os.chdir(CWD)

In [None]:
# Names of the experiments to compare; each one is a sub-directory of the experiments root.
EXP_NAMES = [
    "exc_t_nwpF_excF",
    "exc_t_nwpT_excF",
    "exc_t_nwpF_excT",
    "exc_t_nwpT_excT",
]

In [None]:
# Hex colour palette handed to `plot_sample` through its `colors` argument.
COLORS = [
    "#086788",
    "#4c9a8e",
    "#ff9736",
    "#ffd053",
    "#63bcaf",
    "#e4e4e4",
    "#ffac5f",
    "#7bcdf3",
    "#14120e",
]

In [None]:
# Root directory of the experiment results, relative to the working directory.
# Each experiment lives in `EXP_ROOT/<exp name>/` (error CSVs, `model_<i>.pkl` files, `config` module).
EXP_ROOT = pathlib.Path("exp_results")

In [None]:
def load_testset(exp, split="test", exp_root=None):
    """Load the table of per-sample errors saved for one experiment.

    The file `<exp_root>/<exp>/<split>_errors.csv.gz` is tried first, then the
    uncompressed `<split>_errors.csv`. The first file that exists is used and the
    other one is not read.

    Args:
        exp: Name of the experiment, i.e. the sub-directory of `exp_root`.
        split: Prefix of the errors file to load ("test" by default).
        exp_root: Directory that contains the experiments. Defaults to the
            notebook-level `EXP_ROOT`.

    Returns:
        A DataFrame with the content of the CSV (`pv_id` read as `str`, `ts` parsed
        as dates) plus two columns: `pred_ts` (`ts` shifted by `horizon` minutes)
        and `exp` (the experiment name).

    Raises:
        FileNotFoundError: If neither candidate file exists.
    """
    if exp_root is None:
        exp_root = EXP_ROOT

    candidate_paths = [f"{exp_root}/{exp}/{split}_errors{ext}" for ext in (".csv.gz", ".csv")]

    testset = None
    for path in candidate_paths:
        try:
            testset = pd.read_csv(path, dtype={"pv_id": str}, parse_dates=["ts"])
        except FileNotFoundError:
            continue
        # Stop at the first file found: reading the remaining candidates would be
        # wasted work and would silently override what we just loaded.
        break

    if testset is None:
        raise FileNotFoundError(
            f"No errors file found for experiment {exp!r} (split {split!r}); "
            f"tried: {', '.join(candidate_paths)}"
        )

    testset["pred_ts"] = testset["ts"] + pd.to_timedelta(testset["horizon"], unit="minute")
    testset["exp"] = exp
    return testset


# Stack the error tables of every experiment; the `exp` column tells them apart.
testset = pd.concat(list(map(load_testset, EXP_NAMES)))
_(testset)

In [None]:
# When several models are compared, the ground truth is read through the config of the
# first experiment only.
first_exp_package = f"{EXP_ROOT}.{EXP_NAMES[0]}"
first_exp_config = importlib.import_module(".config", package=first_exp_package).ExpConfig()
pv_ds = first_exp_config.get_pv_data_source()

In [None]:
# Mapping from experiment name to its (meta) model.
models = {}

for name in EXP_NAMES:
    # Import the experiment's `config` module from the same root the models are loaded
    # from below, so that changing `EXP_ROOT` keeps config and models in sync.
    exp_config = importlib.import_module(".config", f"{EXP_ROOT}.{name}").ExpConfig()

    date_splits = exp_config.get_date_splits()
    # One model was saved per train date. The loop variable is deliberately not called
    # `dt`, which is the alias of the `datetime` module in this notebook.
    train_dates = [split.train_date for split in date_splits.train_date_splits]

    # Load the saved models for a given exp: `model_<i>.pkl` matches the i-th train date.
    # NOTE(review): these are `.pkl` files, presumably unpickled by `load_model` - only
    # load experiment directories that you trust.
    model_list = [load_model(EXP_ROOT / name / f"model_{i}.pkl") for i in range(len(train_dates))]
    _models = dict(zip(train_dates, model_list))

    # Wrap them into one big meta model.
    model = MultiPvSiteModel(_models)

    model.set_data_sources(**exp_config.get_data_source_kwargs())

    models[name] = model

In [None]:
# Put the samples with the largest error first.
testset = testset.sort_values(by="error", ascending=False)

In [None]:
# Row count followed by the first 20 rows of the test set.
_(testset, n=20)

In [None]:
from psp.visualization import find_horizon_index

# Position (0-based) of the sample to inspect among the filtered rows below. `testset`
# was sorted by decreasing error in an earlier cell, so 0 is the worst remaining sample.
SAMPLE_IDX = 100

# `shap` itself is already imported in the imports cell at the top of the notebook.
shap.initjs()

# Rows we are willing to look at. More conditions can be AND-ed here when exploring,
# e.g. `testset["horizon"] == 60 * 4` or `testset["ts"].dt.hour < 12`.
# NOTE(review): the reason for excluding pv_id "7759" is not recorded here - confirm.
candidate_rows = testset[
    (testset["y"] > 0.1)
    & (testset["pv_id"] != "7759")
]

# Fail with a readable message instead of a bare positional-indexing error.
if not -len(candidate_rows) <= SAMPLE_IDX < len(candidate_rows):
    raise IndexError(
        f"SAMPLE_IDX={SAMPLE_IDX} is out of range: only {len(candidate_rows)} rows match the filter"
    )

test_row = candidate_rows.iloc[SAMPLE_IDX]

print("test row")
print(test_row)
print()

# Every field is read from the same `row` mapping.
row = dict(test_row)
ts = row["ts"]
horizon = row["horizon"]
pv_id = row["pv_id"]

# To inspect a hand-picked sample instead, override `pv_id`, `ts` and/or `horizon` here,
# e.g. `pv_id = "27000"` or `ts = dt.datetime(2022, 9, 27, 8)`.

# Assume all the models use the same horizons.
horizons = first_exp_config.get_model_config().horizons
horizon_idx = find_horizon_index(horizon, horizons)

plot_sample(
    x=X(pv_id=pv_id, ts=ts),
    horizon_idx=horizon_idx,
    horizons=horizons,
    models=models,
    pv_data_source=pv_ds,
    nwp_data_source=None,
    colors=COLORS,
    resample_pv=True,
)

In [None]:
# Number of random test rows the explanations are averaged over, and the seed that makes
# the draw reproducible across "Restart & Run All".
N_EXPLAINED_SAMPLES = 20
EXPLAIN_SEED = 1234

# Never ask for more rows than there are: `DataFrame.sample` would raise otherwise.
sampled_rows = testset.sample(
    n=min(N_EXPLAINED_SAMPLES, len(testset)),
    random_state=EXPLAIN_SEED,
)
xs = [X(pv_id=pv_id, ts=ts) for pv_id, ts in zip(sampled_rows["pv_id"], sampled_rows["ts"])]

for model_name, model in models.items():
    print(model_name)
    try:
        explanations = [model.explain(x) for x in xs]
    except Exception as exc:
        # Best effort: a model that cannot be explained is skipped, but we say why
        # instead of hiding the failure.
        print(f"  skipped: explain() failed with {exc!r}")
        continue

    if not explanations:
        # Nothing was sampled (empty test set): avoid a division by zero below.
        print("  skipped: no sample to explain")
        continue

    # Average the explanations over the sampled inputs.
    explanation = sum(explanations) / len(explanations)

    # `shap.plots.bar` renders the figure itself; with its default `show=True` it returns
    # None, so wrapping it in `display(...)` would only add a stray "None" to the output.
    shap.plots.bar(explanation)