In [None]:
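# Sample analysis for the RecentHistory model: inspect the test samples with the largest errors.
# TODO: Make this notebook model-agnostic by removing the RecentHistory-specific parts (e.g. the `SetupConfig` import).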

In [None]:
# Reload changed modules (e.g. edits to `psp`) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime as dt
import os
import pathlib

import altair as alt
import numpy as np
import pandas as pd
import plotly.express as px
import pvlib
import shap

from psp.data.data_sources.nwp import NwpDataSource
from psp.data.data_sources.pv import NetcdfPvDataSource
from psp.metrics import mean_absolute_error
from psp.serialization import load_model
from psp.typings import X
from psp.visualization import find_horizon_index, plot_sample

# TODO This should not be specific to the RecentHistory model
from psp.models.recent_history import SetupConfig

# Lift Altair's row limit so charts can be built from large DataFrames.
# NOTE(review): all rows then likely get embedded in the chart output, which can bloat the notebook.
alt.data_transformers.disable_max_rows()


def _(df, *args, **kwargs):
    """Print how many rows ``df`` has, then rich-display its head.

    Any extra positional/keyword arguments are passed straight to ``df.head``,
    e.g. ``_(frame, 20)`` previews 20 rows instead of the default.

    NOTE(review): defining ``_`` shadows IPython's "last output" variable in this
    kernel; the name is kept because later cells call it.
    """
    row_count = len(df)
    print(row_count)
    preview = df.head(*args, **kwargs)
    display(preview)

In [None]:
# Setting the working directory by hand is tedious, so the Makefile defines a `CWD`
# environment variable. When it is set to a non-empty value, switch into it so the
# relative data paths used below resolve.
CWD = os.getenv("CWD")
if CWD is not None and CWD != "":
    os.chdir(CWD)

In [None]:
# Show the current working directory (after the optional CWD switch).
%pwd

In [None]:
# Experiment under review: its model (`.pkl`) and per-sample error table live in
# exp_results/<EXP>/.
EXP = "deleteme"  # NOTE(review): looks like a placeholder experiment name.
MODEL = "/".join(("exp_results", EXP, "model.pkl"))
TESTSET = "/".join(("exp_results", EXP, "errors.csv"))

# Raw inputs: PV readings in a local NetCDF file (presumably 5-minute resolution,
# per the filename) and UKV NWP data in a Zarr store on GCS.
PV_DATA = "data/5min.netcdf"
NWP_DATA = (
    "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_intermediate_version_3.zarr"
)

# PV system metadata files, both keyed by `ss_id` when loaded below.
META = "data/metadata_sensitive.csv"
META_INFERRED = "data/meta_inferred.csv"

In [None]:
# Per-sample errors of the experiment. `horizon` is an offset in minutes from `ts`,
# so `pred_ts` = ts + horizon is the time each prediction targets.
testset = pd.read_csv(TESTSET, dtype={"pv_id": str}, parse_dates=["ts"]).assign(
    pred_ts=lambda frame: frame["ts"]
    + pd.to_timedelta(frame["horizon"], unit="minute")
)
_(testset)

In [None]:
# Data sources handed to the model below: PV data from the local NetCDF file and
# NWP data from the Zarr store referenced by NWP_DATA.
pv_ds, nwp_ds = NetcdfPvDataSource(PV_DATA), NwpDataSource(NWP_DATA)

In [None]:
# NOTE(review): MODEL is a `.pkl` file, so `load_model` presumably unpickles it.
# Only load artefacts you trust, because unpickling can execute arbitrary code.
model = load_model(MODEL)
# Give the loaded model access to the PV and NWP data sources.
model.setup(
    SetupConfig(
        nwp_data_source=nwp_ds,
        pv_data_source=pv_ds,
    )
)

In [None]:
# Both metadata tables are indexed by `ss_id`, read as a string (like `pv_id`).
meta, meta_inferred = (
    pd.read_csv(path, dtype={"ss_id": str}).set_index("ss_id")
    for path in (META, META_INFERRED)
)
# NOTE(review): META points at a "sensitive" metadata file; clear this cell's
# output before committing or sharing the notebook.
_(meta)
_(meta_inferred)

In [None]:
# Worst predictions first.
testset = testset.sort_values(by=["error"], ascending=[False])

In [None]:
# Preview the 20 samples with the largest error.
_(testset, n=20)

In [None]:
# Pick one high-error test sample and plot the model's inputs and prediction for it.

# Load SHAP's JavaScript so SHAP visualisations can render in the notebook.
# NOTE(review): nothing in this cell draws SHAP plots directly; drop this if
# `plot_sample` doesn't either.
shap.initjs()

# Selection knobs. Keep horizons shorter than one day (`horizon` is in minutes)
# and targets `y` above MIN_Y, then take the row at position SAMPLE_POSITION of
# the filtered test set (sorted by descending error earlier).
MAX_HORIZON_MINUTES = 60 * 24.0
MIN_Y = 10
SAMPLE_POSITION = 30

candidates = testset[
    (testset["horizon"] < MAX_HORIZON_MINUTES) & (testset["y"] > MIN_Y)
]
if len(candidates) <= SAMPLE_POSITION:
    # Fail with an explicit message instead of pandas' generic out-of-bounds error.
    raise IndexError(
        f"Only {len(candidates)} test rows match the filter "
        f"(horizon < {MAX_HORIZON_MINUTES} min, y > {MIN_Y}); "
        f"cannot select position {SAMPLE_POSITION}."
    )
test_row = candidates.iloc[SAMPLE_POSITION]
print("test row")
display(test_row)

row = dict(test_row)
pv_id = row["pv_id"]
ts = row["ts"]
horizon_idx = find_horizon_index(row["horizon"], model.config.horizons)

# Manual overrides for exploring other samples, e.g.:
#   ts = ts + dt.timedelta(days=30 * 6 + 7)
#   horizon_idx = 10 * 60 // 15
#   pv_id, ts = "27000", dt.datetime(2021, 7, 31, 12, 30)
#   ts = row["ts"] + dt.timedelta(days=9.5, minutes=60)

plot_sample(
    x=X(pv_id=pv_id, ts=ts),
    horizon_idx=horizon_idx,
    model=model,
    pv_data_source=pv_ds,
    nwp_data_source=nwp_ds,
    meta=meta,
    do_nwp=False,
)