In [None]:
# TODO:
#    1. Early stopping is not working in the Optuna experiment.

In [None]:
import os
import shutil
import pickle
import random
import sys
import numpy as np
import pandas as pd
import duckdb
from typing import List

import requests
from io import StringIO

import ibis
import ibis.selectors as s
from ibis import _
ibis.options.interactive = True

from sklearn.preprocessing import RobustScaler

import torch

from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import (
    Scaler,
    MissingValuesFiller,
    Mapper,
    InvertibleMapper,
)
from darts.dataprocessing import Pipeline
from darts.metrics import mape, smape, mae, ope, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, IceCreamHeaterDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.likelihood_models import QuantileRegression, GumbelLikelihood, GaussianLikelihood

from darts import TimeSeries
from darts.utils.timeseries_generation import (
    gaussian_timeseries,
    linear_timeseries,
    sine_timeseries,
)
from darts.models import (
    TFTModel,
    TiDEModel,
    DLinearModel,
    NLinearModel,
    TSMixerModel
)


from torchmetrics import (
    SymmetricMeanAbsolutePercentageError, 
    MeanAbsoluteError, 
    MeanSquaredError,
)

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import mlflow

import warnings
import logging

# Silence every warning for the rest of the session to keep cell output readable.
warnings.filterwarnings("ignore")

# Root logger at INFO so the `log.info(...)` calls in later cells are emitted.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

In [None]:
# import optuna
# from optuna.integration import PyTorchLightningPruningCallback
# from optuna.visualization import (
#     plot_optimization_history,
#     plot_contour,
#     plot_param_importances,
#     plot_pareto_front,
# )

In [None]:
## will be loaded from root when deployed
# from darts_wrapper import DartsGlobalModel

In [None]:
# Move to the project root (two directories up) so `src` and `data/spp.ddb` resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not (os.path.isdir('src') and os.path.isdir('data')):
    os.chdir('../..')

In [None]:
# custom modules
import src.data_engineering as de
from src import params
from src import plotting
from src import utils
from src.modeling import get_ci_err, build_fit_tsmixerx, log_pretty

In [None]:
# Log the forecast horizon and input chunk length configured in src.params.
log.info(f'FORECAST_HORIZON: {params.FORECAST_HORIZON}')
log.info(f'INPUT_CHUNK_LENGTH: {params.INPUT_CHUNK_LENGTH}')

## Load model

In [None]:
# mlflow.set_tracking_uri("sqlite:///mlruns.db")
log.info(f'mlflow.get_tracking_uri(): {mlflow.get_tracking_uri()}')
exp_name = 'spp_weis'
exp = mlflow.get_experiment_by_name(exp_name)
# get_experiment_by_name returns None when no experiment has this name. Fail here
# with a clear message rather than with an AttributeError on `exp.experiment_id`
# in the run-search cell.
if exp is None:
    raise ValueError(
        f'MLflow experiment {exp_name!r} not found at tracking URI '
        f'{mlflow.get_tracking_uri()}'
    )
exp

In [None]:
# All runs of the experiment, newest end_time first. Ordering is by end_time,
# not by a metric, so row 0 is the most recently finished run.
# Alternative ranking: order_by=['metrics.test_mae']
runs = (
    mlflow.search_runs(
        experiment_ids=[exp.experiment_id],  # documented type is a list of IDs
        order_by=['end_time DESC'],
    )
    # Client-side sort as well: pandas places runs with no end_time (e.g. still
    # active) last, regardless of how the backend orders nulls.
    .sort_values('end_time', ascending=False)
)
runs.head()

In [None]:
# NOTE: `runs` is sorted by end_time (newest first), so this is the latest run,
# not necessarily the best-scoring one despite the variable name.
best_run_id = runs.iloc[0]['run_id']
best_run_id

In [None]:
# Artifact root of the selected run; the model path in the next cell is built from it.
runs['artifact_uri'].iloc[0]

In [None]:
# Model URI = the selected run's artifact root + the 'GlobalForecasting' artifact path.
model_path = '/'.join([runs['artifact_uri'].iloc[0], 'GlobalForecasting'])

In [None]:
# Load the logged model as a generic pyfunc; its .predict() is fed a pandas
# DataFrame in the plotting cell below.
loaded_model = mlflow.pyfunc.load_model(model_uri=model_path)

## Prep data

In [None]:
# Build the pandas feature frame and the prepared LMP frame from the DuckDB file.
# try/finally guarantees the read-only connection is closed even if a prep step
# raises, instead of leaving it open in the kernel.
con = ibis.duckdb.connect("data/spp.ddb", read_only=True)
try:
    all_df_pd = de.all_df_to_pandas(de.prep_all_df(con))
    lmp = de.prep_lmp(con)
    lmp_pd_df = (
        lmp
        .to_pandas()
        .set_index('timestamp_mst')
    )
finally:
    con.disconnect()

In [None]:
# Reopen a read-only connection (the data-prep cell closed its own). It is left
# open because `lmp` below is an ibis table that later cells still query.
con = ibis.duckdb.connect("data/spp.ddb", read_only=True)

In [None]:
# NOTE(review): this rebinds `lmp`, replacing the prepared relation from the
# data-prep cell with the raw 'lmp' table. The forecast/plot cell's
# `lmp.to_pandas()` reads this raw table — confirm that is intended.
lmp = con.table('lmp')
lmp

In [None]:
# Latest timestamp_mst_HE in the raw LMP table (presumably hour-ending — confirm).
lmp.timestamp_mst_HE.max()

In [None]:
# all_df_pd = de.all_df_to_pandas(de.prep_all_df())

In [None]:
# lmp = de.prep_lmp()

In [None]:
# lmp_pd_df = (
#     lmp
#     .to_pandas()
#     .set_index('timestamp_mst')
# )

In [None]:
# Latest timestamp in the prepared LMP frame (indexed by timestamp_mst).
lmp_pd_df.index.max()

In [None]:
# Available node IDs; the user will eventually select from these values.
lmp_pd_df.unique_id.unique()

In [None]:
# Node to forecast and plot (to be selected by the user).
plot_node_name = 'PSCO_PRPM_PR'
# Filter on plot_node_name rather than repeating the literal, so changing the node
# above changes every downstream selection consistently.
price_df = lmp_pd_df.loc[lmp_pd_df['unique_id'] == plot_node_name]
price_df

In [None]:
# Feature rows for the same node selected above (uses plot_node_name, not a literal).
node_all_df_pd = all_df_pd.loc[all_df_pd['unique_id'] == plot_node_name]
node_all_df_pd

In [None]:
# First darts series built from the selected node's prices (presumably the only one,
# since price_df holds a single node — confirm).
plot_series = de.get_all_series(price_df)[0]

In [None]:
plot_series.plot()

In [None]:
# Static covariate `unique_id` for the series' LMP component.
plot_series.static_covariates.unique_id.LMP

In [None]:
# Last timestamp of the observed series; forecast_start is derived from it below.
plot_series.end_time()

In [None]:
# Future covariates for the selected node (first series returned).
future_cov_series = de.get_futr_cov(node_all_df_pd)[0]
future_cov_series.plot()

In [None]:
# Future covariates presumably need to extend past forecast_start by the forecast
# horizon for prediction — verify against this end time.
future_cov_series.end_time()

In [None]:
# Past covariates for the selected node (first series returned).
past_cov_series = de.get_past_cov(node_all_df_pd)[0]
past_cov_series.plot()

## Test plotting

In [None]:
# selected by user
forecast_start = pd.Timestamp('2024-07-31') + pd.Timedelta('1h')
forecast_start = plot_series.end_time()  - pd.Timedelta('72h')

In [None]:
# History the model conditions on. Per the darts docs, drop_after removes
# forecast_start itself and every later point.
node_series = plot_series.drop_after(split_point=forecast_start)
log.info(f'node_series.end_time(): {node_series.end_time()}')

In [None]:
# Request for the MLflow pyfunc model: a single row of JSON-serialised darts series
# plus forecast settings (column names presumably mirror what the wrapper reads — confirm).
# num_samples was previously set to 200 and then silently overwritten with 500;
# it is now one value.
NUM_SAMPLES = 500  # passed to the model as `num_samples`
df = pd.DataFrame({
    'series': [node_series.to_json()],
    'past_covariates': [past_cov_series.to_json()],
    'future_covariates': [future_cov_series.to_json()],
    'n': params.FORECAST_HORIZON,
    'num_samples': NUM_SAMPLES,
})

# Future covariates as a flat frame with a `time` column, for plotting.
plot_cov_df = (
    future_cov_series.pd_dataframe()
    .reset_index()
    .rename(columns={'timestamp_mst': 'time', 're_ratio': 'Ratio'})
)

# Predict on a Pandas DataFrame; the returned JSON is parsed once and reused below.
pred = loaded_model.predict(df)
preds = TimeSeries.from_json(pred)

q_df = plotting.get_quantile_df(preds)

# Actual LMPs renamed to the column names passed to plotting.get_plot_df.
# NOTE(review): `lmp` here is the raw 'lmp' table bound by the con.table('lmp') cell.
lmp_df = lmp.to_pandas().rename(
    columns={
        'LMP': 'LMP_HOURLY',
        'unique_id': 'node',
        'timestamp_mst': 'time',
    })

plot_df = (
    plotting.get_plot_df(preds, plot_cov_df, lmp_df, plot_node_name)
    .rename(columns={'mean': 'mean_fcast'})
)

plotting.plotly_forecast(plot_df, plot_node_name, show_fig=False)



In [None]:
# Inspect the request frame that was passed to loaded_model.predict.
df