In [None]:
# TODO:
#    1. early stopping is not working in the Optuna experiment

In [None]:
# Model family tuned by the Optuna study: 'tft' or 'tsmixer'.
# It has to be defined: the "Start Experiment" cell branches on TEST_MODEL and
# raises ValueError for any other value (and NameError if it is missing).
# TEST_MODEL = 'tft'
TEST_MODEL = 'tsmixer'

# flag for whether the hyperparameter search should be run
RUN_EXP = True
# number of Optuna trials passed to study.optimize
NUM_TRIALS = 40

In [None]:
import os
import shutil
import pickle
import random
import sys
import numpy as np
import pandas as pd
import duckdb
from typing import List

import requests
from io import StringIO

import ibis
import ibis.selectors as s
from ibis import _
ibis.options.interactive = True

from sklearn.preprocessing import RobustScaler

from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import (
    Scaler,
    MissingValuesFiller,
    Mapper,
    InvertibleMapper,
)
from darts.dataprocessing import Pipeline
from darts.metrics import mape, smape, mae, ope, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, IceCreamHeaterDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.likelihood_models import QuantileRegression, GumbelLikelihood, GaussianLikelihood

from darts import TimeSeries
from darts.utils.timeseries_generation import (
    gaussian_timeseries,
    linear_timeseries,
    sine_timeseries,
)
from darts.models import (
    TFTModel,
    TiDEModel,
    DLinearModel,
    NLinearModel,
    TSMixerModel
)


from torchmetrics import (
    SymmetricMeanAbsolutePercentageError, 
    MeanAbsoluteError, 
    MeanSquaredError,
)

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import warnings
warnings.filterwarnings("ignore")

# logging
import logging

# define log
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

In [None]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.visualization import (
    plot_optimization_history,
    plot_contour,
    plot_param_importances,
    plot_pareto_front,
)

In [None]:
## will be loaded from root when deployed
from darts_wrapper import DartsGlobalModel

In [None]:
os.chdir('../..')

In [None]:
# custom modules
import src.data_engineering as de
from src import plotting
from src import utils

## Data prep

In [None]:
# connect to database
con = ibis.duckdb.connect("data/spp.ddb")
con.list_tables()

In [None]:
lmp = con.table('lmp')
lmp

In [None]:
lmp = lmp.filter(_.Settlement_Location_Name.contains('PSCO'))

In [None]:
lmp.to_pandas()[['GMTIntervalEnd_HE', 'Settlement_Location_Name']].duplicated().sum()

In [None]:
drop_cols = [
    'Interval_HE', 'GMTIntervalEnd_HE', 'timestamp_mst_HE',
    'Settlement_Location_Name', 'PNODE_Name', 
    'MLC', 'MCC', 'MEC'
]

lmp = (
    lmp
    .mutate(unique_id = _.Settlement_Location_Name )
    .mutate(timestamp_mst = _.timestamp_mst_HE)
    # .mutate(y = _.LMP) 
    .drop(drop_cols) 
    .order_by(['unique_id', 'timestamp_mst'])
)

lmp

In [None]:
mtrf = con.table('mtrf')
mtrf

In [None]:
drop_cols = ['Interval', 'GMTIntervalEnd']

mtrf = (
    mtrf
    # .mutate(ds = _.timestamp_mst)
    .drop(drop_cols) 
    .order_by(['timestamp_mst'])
)

mtrf

In [None]:
mtlf = con.table('mtlf')
mtlf

In [None]:
drop_cols = ['Interval', 'GMTIntervalEnd',]

mtlf = (
    mtlf
    # .mutate(ds = _.timestamp_mst)
    .drop(drop_cols) 
    .order_by(['timestamp_mst'])
)

mtlf

In [None]:
all_df = (
    mtlf
    .left_join(mtrf, 'timestamp_mst')
    .select(~s.contains("_right")) # remove 'dt_right'
    .left_join(lmp, 'timestamp_mst')
    .select(~s.contains("_right")) # remove 'dt_right'
    .order_by(['unique_id', 'timestamp_mst'])
)
all_df 


In [None]:
all_df.describe()

In [None]:
# re_ratio: share of forecast load covered by forecast wind + solar.
# re_diff: hour-over-hour change of that share. The lag is evaluated inside an
# explicit window (per node, ordered by time); without one the "previous row"
# is not tied to the previous hour and can come from a different node.
all_df = (
    all_df
    .drop_null(['unique_id'])
    .mutate(re_ratio = (_.Wind_Forecast_MW + _.Solar_Forecast_MW) / _.MTLF)
    .mutate(
        re_diff = _.re_ratio - _.re_ratio.lag(1).over(
            ibis.window(group_by='unique_id', order_by='timestamp_mst')
        )
    )
)

all_df

In [None]:
all_df_pd = all_df.to_pandas()
all_df_pd

In [None]:
all_df_pd.info()

In [None]:
len(all_df_pd.timestamp_mst.unique()) * len(all_df_pd.unique_id.unique())

In [None]:
node_groups = all_df_pd.unique_id.unique()
log.info(f'number of nodes: {len(node_groups)}')
node_groups

In [None]:
node_groups = [node for node in node_groups if 'PSCO_' in node]
log.info(f'number of nodes: {len(node_groups)}')
node_groups

In [None]:
all_df_pd = all_df_pd[all_df_pd.unique_id.isin(node_groups)].reset_index(drop=True)
all_df_pd

In [None]:
# all_df_pd.drop(['Averaged_Actual'], axis=1, errors='ignore', inplace=True)
all_df_pd.set_index('timestamp_mst', inplace=True)
all_df_pd

## Prep model training

In [None]:
FORECAST_HORIZON = 24*5
INPUT_CHUNK_LENGTH = 2*FORECAST_HORIZON

In [None]:
futr_cols = ['MTLF', 'Wind_Forecast_MW', 'Solar_Forecast_MW', 're_ratio', 're_diff']
past_cols = ['Averaged_Actual']
y = ['LMP']
ids = ['unique_id']

In [None]:
all_df_pd = all_df_pd[ids + y + past_cols + futr_cols]
all_df_pd 

In [None]:
train_start = all_df_pd.index.min() + pd.Timedelta(f'{2*INPUT_CHUNK_LENGTH}h')
test_end = all_df_pd.index.max() - pd.Timedelta(f'{2*FORECAST_HORIZON}h')
tr_tst_split =  test_end - pd.Timedelta(f'{2*INPUT_CHUNK_LENGTH}h')
log.info(f'train_start: {train_start}')
log.info(f'tr_tst_split: {tr_tst_split}')
log.info(f'test_end: {test_end}')

In [None]:
train_idx = (all_df_pd.index < tr_tst_split) & (all_df_pd.index > train_start)
test_idx = (all_df_pd.index > tr_tst_split) & (all_df_pd.index < test_end)
train_all = all_df_pd[train_idx]
train_all

In [None]:
test_all = all_df_pd[test_idx]
test_all

In [None]:
def fill_missing(series):
    """Replace every element of ``series`` with its missing-values-filled version.

    The list is modified in place (each slot is overwritten with the output of
    a fresh ``MissingValuesFiller().transform``); nothing is returned.
    """
    for idx, ts in enumerate(series):
        series[idx] = MissingValuesFiller().transform(ts)

In [None]:
all_series = TimeSeries.from_group_dataframe(
    all_df_pd,
    group_cols=ids,
    value_cols=y,
    fill_missing_dates=True,
    freq='h',
)

fill_missing(all_series) 
all_series[0].plot()

In [None]:
train_series = TimeSeries.from_group_dataframe(
    train_all,
    group_cols=ids,
    value_cols=y,
    fill_missing_dates=True,
    freq='h',
)
fill_missing(train_series)
train_series[0].plot()

In [None]:
test_series = TimeSeries.from_group_dataframe(
    test_all,
    group_cols=ids,
    value_cols=y,
    fill_missing_dates=True,
    freq='h',
)
fill_missing(test_series)
test_series[0].plot()

In [None]:
futr_cov = TimeSeries.from_group_dataframe(
    all_df_pd,
    group_cols=ids,
    value_cols=futr_cols,
    fill_missing_dates=True,
    freq='h',
)
fill_missing(futr_cov)
futr_cov[0].plot()

In [None]:
past_cov = TimeSeries.from_group_dataframe(
    all_df_pd,
    group_cols=ids,
    value_cols=past_cols,
    fill_missing_dates=True,
    freq='h',
)
fill_missing(past_cov)
past_cov[0].plot()

## MLFlow setup

In [None]:
import mlflow

In [None]:
# mlflow.set_tracking_uri("sqlite:///mlruns.db")
mlflow.get_tracking_uri()

In [None]:
exp_name = 'spp_weis'

if mlflow.get_experiment_by_name(exp_name) is None:
    exp = mlflow.create_experiment(exp_name)
    
exp = mlflow.get_experiment_by_name(exp_name)
exp

## Get model signature

In [None]:
node_series = train_series[0]
future_cov_series = futr_cov[0]
past_cov_series = past_cov[0]

data = {
    'series': [node_series.to_json()],
    'past_covariates': [past_cov_series.to_json()],
    'future_covariates': [future_cov_series.to_json()],
    'n': FORECAST_HORIZON,
    'num_samples': 200
}
df = pd.DataFrame(data)

ouput_example = 'the endpoint return json as a string'

In [None]:
from mlflow.models import infer_signature
darts_signature = infer_signature(df, ouput_example)
darts_signature

## Set up hyperparameter tuning study

https://unit8co.github.io/darts/examples/17-hyperparameter-optimization.html?highlight=optuna

In [None]:
import torch
import pprint
# shared pretty printer: 2-space indent, dict keys kept in insertion order
pp = pprint.PrettyPrinter(indent=2, sort_dicts=False)

def log_pretty(obj):
    """Return ``obj`` formatted by ``pp`` followed by a newline, for log messages."""
    return pp.pformat(obj) + '\n'
    
def build_fit_tsmixerx(
    series: List[TimeSeries]=train_series,
    val_series: List[TimeSeries]=test_series,
    future_covariates: List[TimeSeries]=futr_cov,
    past_covariates: List[TimeSeries]=past_cov,
    hidden_size: int=8,
    ff_size: int=96,
    num_blocks: int=1,
    forecast_horizon: int=FORECAST_HORIZON,
    input_chunk_length: int=INPUT_CHUNK_LENGTH,
    lr: float=2.5e-4,
    batch_size: int=64,
    n_epochs: int=7,
    force_reset: bool=True, # reset model if already exists
    callbacks=None,
):
    """Build a quantile-regression TSMixerModel, fit it, and reload it from its checkpoint.

    Args:
        series: target series to train on.
        val_series: target series used as validation data during fitting.
        future_covariates: future covariates, used for training and validation.
        past_covariates: past covariates, used for training and validation.
        hidden_size: forwarded to ``TSMixerModel``.
        ff_size: forwarded to ``TSMixerModel``.
        num_blocks: forwarded to ``TSMixerModel``.
        forecast_horizon: used as the model's ``output_chunk_length``.
        input_chunk_length: forwarded to ``TSMixerModel``.
        lr: learning rate, passed as ``optimizer_kwargs={"lr": lr}``.
        batch_size: forwarded to ``TSMixerModel``.
        n_epochs: forwarded to ``TSMixerModel``.
        force_reset: forwarded to ``TSMixerModel`` (reset model if already exists).
        callbacks: optional list of Lightning callbacks; see the NOTE(review)
            in the body, they are currently not handed to the model.

    Returns:
        The model loaded back from the checkpoint directory
        ``<cwd>/model_checkpoints`` after fitting, with an extra ``MODEL_TYPE``
        attribute set to ``"ts_mixer_model"``.
    """
    work_dir = os.getcwd() + '/model_checkpoints'
    MODEL_TYPE = "ts_mixer_model"
    # 0.01, 0.05 ... 0.95, 0.99
    quantiles = [0.01]+np.arange(0.05, 1, 0.05).tolist()+[0.99]

    #TODO: pick a metric...
    torch_metrics = MeanAbsoluteError()
    # torch_metrics = MeanSquaredError(squared=False)
    # torch_metrics = SymmetricMeanAbsolutePercentageError() # don't use...

    # calendar and relative-position encodings generated by darts, scaled with
    # a globally fitted RobustScaler
    encoders = {
        "datetime_attribute": {
            "future": ["hour", "dayofweek", "month"],
            "past": ["hour", "dayofweek", "month"],
        },
        "position": {
            "past": ["relative"],
            "future": ["relative"]
        },
        "transformer": Scaler(RobustScaler(), global_fit=True)
    }

    # common parameters across models
    model_params = {
        'hidden_size': hidden_size,
        'ff_size': ff_size,
        'num_blocks': num_blocks,
        'input_chunk_length': input_chunk_length,
        'output_chunk_length': forecast_horizon,
        'batch_size': batch_size,
        'n_epochs': n_epochs,
        'add_encoders': encoders,
        'likelihood': QuantileRegression(quantiles=quantiles),  # QuantileRegression is set per default
        'optimizer_kwargs': {"lr": lr},
        'random_state': 42,
        'torch_metrics': torch_metrics,
        'use_static_covariates': False,
        'save_checkpoints': True,
        'work_dir': work_dir,
        'model_name': MODEL_TYPE, # used for checkpoint saves
        'force_reset': force_reset, # reset model if already exists
    }

    # throughout training we'll monitor the validation loss for early stopping
    early_stopper = EarlyStopping("val_loss", min_delta=0.01, patience=3, verbose=True)
    if callbacks is None:
        callbacks = [early_stopper]
    else:
        callbacks = [early_stopper] + callbacks

    # NOTE(review): the callbacks are assembled here but NOT handed to the
    # model, because the assignment below is commented out. As a result neither
    # early stopping nor any caller-supplied callback runs during fit().
    # Confirm whether that is intentional before re-enabling it.
    pl_trainer_kwargs = {"callbacks": callbacks}
    # model_params['pl_trainer_kwargs'] = pl_trainer_kwargs
    log.info(f'model_params: \n{log_pretty(model_params)}')

    model = TSMixerModel(**model_params)

    # train the model on the data passed by the caller; the same covariates
    # serve the training and the validation series
    fit_params = {
        'series': series,
        'val_series': val_series,
        'future_covariates': future_covariates,
        'past_covariates': past_covariates,
        'val_future_covariates': future_covariates,
        'val_past_covariates': past_covariates,
    }
    model.fit(**fit_params)

    # reload best model over course of training
    model = TSMixerModel.load_from_checkpoint(
        work_dir=work_dir,
        model_name=MODEL_TYPE
    )

    # remembered on the instance so later cells can name the saved artifacts
    model.MODEL_TYPE = MODEL_TYPE

    return model
    

In [None]:
def build_fit_tft(
    series: List[TimeSeries]=train_series,
    val_series: List[TimeSeries]=test_series,
    future_covariates: List[TimeSeries]=futr_cov,
    past_covariates: List[TimeSeries]=past_cov,
    hidden_size: int=12, # Hidden state size of the TFT. It is the main hyper-parameter and common across the internal TFT architecture.
    lstm_layers: int = 1, # Number of layers for the Long Short Term Memory (LSTM) Encoder and Decoder (1 is a good default).
    num_attention_heads: int=2, # Number of attention heads (4 is a good default)
    dropout: float=0.1,
    forecast_horizon: int=FORECAST_HORIZON,
    input_chunk_length: int=INPUT_CHUNK_LENGTH,
    lr: float=2.5e-4,
    batch_size: int=64,
    n_epochs: int=6,
    force_reset: bool=True, # reset model if already exists
    callbacks=None,
):
    """Build a quantile-regression TFTModel, fit it, and reload it from its checkpoint.

    Args:
        series: target series to train on.
        val_series: target series used as validation data during fitting.
        future_covariates: future covariates, used for training and validation.
        past_covariates: past covariates, used for training and validation.
        hidden_size: forwarded to ``TFTModel``.
        lstm_layers: forwarded to ``TFTModel``.
        num_attention_heads: forwarded to ``TFTModel``.
        dropout: forwarded to ``TFTModel``.
        forecast_horizon: used as the model's ``output_chunk_length``.
        input_chunk_length: forwarded to ``TFTModel``.
        lr: learning rate, passed as ``optimizer_kwargs={"lr": lr}``.
        batch_size: forwarded to ``TFTModel``.
        n_epochs: forwarded to ``TFTModel``.
        force_reset: forwarded to ``TFTModel`` (reset model if already exists).
        callbacks: optional list of Lightning callbacks; see the NOTE(review)
            in the body, they are currently not handed to the model.

    Returns:
        The model loaded back from the checkpoint directory
        ``<cwd>/model_checkpoints`` after fitting, with an extra ``MODEL_TYPE``
        attribute set to ``"tft_model"``.
    """
    work_dir = os.getcwd() + '/model_checkpoints'
    MODEL_TYPE = "tft_model"
    # 0.01, 0.05 ... 0.95, 0.99
    quantiles = [0.01]+np.arange(0.05, 1, 0.05).tolist()+[0.99]

    #TODO: pick a metric...
    torch_metrics = MeanAbsoluteError()
    # torch_metrics = MeanSquaredError(squared=False)
    # torch_metrics = SymmetricMeanAbsolutePercentageError() # don't use...

    # calendar and relative-position encodings generated by darts, scaled with
    # a globally fitted RobustScaler
    encoders = {
        "datetime_attribute": {
            "future": ["hour", "dayofweek", "month"],
            "past": ["hour", "dayofweek", "month"],
        },
        "position": {
            "past": ["relative"],
            "future": ["relative"]
        },
        "transformer": Scaler(RobustScaler(), global_fit=True)
    }

    # common parameters across models
    model_params = {
        'hidden_size': hidden_size,
        'lstm_layers': lstm_layers,
        'num_attention_heads': num_attention_heads,
        'dropout': dropout,
        'input_chunk_length': input_chunk_length,
        'output_chunk_length': forecast_horizon,
        'batch_size': batch_size,
        'n_epochs': n_epochs,
        'add_encoders': encoders,
        'likelihood': QuantileRegression(quantiles=quantiles),  # QuantileRegression is set per default
        'optimizer_kwargs': {"lr": lr},
        'random_state': 42,
        'torch_metrics': torch_metrics,
        'use_static_covariates': False,
        'save_checkpoints': True,
        'work_dir': work_dir,
        'model_name': MODEL_TYPE, # used for checkpoint saves
        'force_reset': force_reset, # reset model if already exists
    }

    # throughout training we'll monitor the validation loss for early stopping
    early_stopper = EarlyStopping("val_loss", min_delta=0.01, patience=3, verbose=True)
    if callbacks is None:
        callbacks = [early_stopper]
    else:
        callbacks = [early_stopper] + callbacks

    # NOTE(review): the callbacks are assembled here but NOT handed to the
    # model, because the assignment below is commented out. As a result neither
    # early stopping nor any caller-supplied callback runs during fit().
    # Confirm whether that is intentional before re-enabling it.
    pl_trainer_kwargs = {"callbacks": callbacks}
    # model_params['pl_trainer_kwargs'] = pl_trainer_kwargs
    log.info(f'model_params: \n{log_pretty(model_params)}')

    model = TFTModel(**model_params)

    # train the model on the data passed by the caller; the same covariates
    # serve the training and the validation series
    fit_params = {
        'series': series,
        'val_series': val_series,
        'future_covariates': future_covariates,
        'past_covariates': past_covariates,
        'val_future_covariates': future_covariates,
        'val_past_covariates': past_covariates,
    }
    model.fit(**fit_params)

    # reload best model over course of training
    model = TFTModel.load_from_checkpoint(
        work_dir=work_dir,
        model_name=MODEL_TYPE
    )

    # remembered on the instance so later cells can name the saved artifacts
    model.MODEL_TYPE = MODEL_TYPE

    return model

In [None]:
# test build fit function
model = build_fit_tsmixerx(n_epochs=1)
# model = build_fit_tft(n_epochs=1)

In [None]:
model.MODEL_TYPE

In [None]:
model.model_params

In [None]:
preds = model.predict(
        series=train_series, 
        n=FORECAST_HORIZON,
        past_covariates=past_cov,
        future_covariates=futr_cov,
        num_samples=200,
)

errs = rmse(test_series, preds, n_jobs=-1, verbose=True)
errs = np.mean(errs)
errs

In [None]:
def get_ci_err(actual_series, pred_series, n_jobs=1, verbose=False):
    """Return how far (in percentage points) the 10%-90% band is from 80% coverage.

    The 0.1 and 0.9 quantiles of ``pred_series`` are inner-joined on the index
    with the ``LMP`` column of ``actual_series``. Coverage is the share of rows
    where the actual lies strictly between both quantiles, and the result is
    ``100 * |coverage - 0.8|``. ``n_jobs`` and ``verbose`` are accepted but not
    used.
    """
    quantile_df = pred_series.quantiles_df((0.1, 0.9))
    actual_df = actual_series.pd_dataframe()

    joined = quantile_df.merge(
        actual_df,
        how='inner',
        left_index=True,
        right_index=True,
    )

    actual = joined['LMP']
    below_upper = joined['LMP_0.9'] > actual
    above_lower = joined['LMP_0.1'] < actual

    # share of points inside the band; should be about 80%
    coverage = (below_upper & above_lower).mean()

    return 100 * np.abs(coverage - 0.8)

In [None]:
get_ci_err(test_series[0], preds[0])

In [None]:
def get_ci_cover_err(actual_series, preds):
    """Mean distance by which the actuals fall outside the 10%-90% forecast band.

    The 0.1 and 0.9 quantiles of ``preds`` are inner-joined on the index with
    the ``LMP`` column of ``actual_series``. For each row the error is how far
    the actual lies above the 0.9 quantile plus how far it lies below the 0.1
    quantile (both are 0 when the actual is inside the band).

    Args:
        actual_series: object exposing ``pd_dataframe()`` with an ``LMP`` column.
        preds: object exposing ``quantiles_df((0.1, 0.9))`` with ``LMP_0.1`` and
            ``LMP_0.9`` columns.

    Returns:
        The mean per-row error over the joined rows.
    """
    series_qs = preds.quantiles_df((0.1, 0.9))
    val_y = actual_series.pd_dataframe()

    eval_df = series_qs.merge(
        val_y,
        how='inner',
        left_index=True,
        right_index=True,
    )

    # overshoot: actual above the upper bound
    over = (eval_df['LMP'] - eval_df['LMP_0.9']).clip(lower=0)
    # undershoot: actual below the lower bound. The earlier version tested
    # `LMP_0.1 < LMP`, which penalised every point ABOVE the lower bound.
    under = (eval_df['LMP_0.1'] - eval_df['LMP']).clip(lower=0)

    return (over + under).mean()

In [None]:
np.mean([get_ci_cover_err(test_series[i], preds[i]) for i in range(len(preds))])

In [None]:
### retest this....
# acc = model.backtest(
#     series=test_series,
#     past_covariates=past_cov,
#     future_covariates=futr_cov,
#     retrain=False,
#     forecast_horizon=FORECAST_HORIZON,
#     stride=25,
#     metric=[get_ci_cover_err],
#     verbose=False,
#     num_samples=200,
    # retrain=False,
# )

In [None]:
def objective_tsmixer(trial):
    """Optuna objective for the TSMixer model.

    Samples hyperparameters from ``trial``, trains a model with
    ``build_fit_tsmixerx``, forecasts ``len(test_series[0])`` steps past the
    end of the training series and scores the forecast against ``test_series``.

    Returns:
        ``(err_metric, ci_error)``: the mean RMSE over all series and the mean
        of ``get_ci_cover_err`` over all series. A NaN score is returned as
        ``inf``.
    """
    # NOTE(review): confirm that build_fit_tsmixerx really forwards `callbacks`
    # to the Lightning trainer, and that pruning is usable in this study
    # (Optuna documents trial.report / should_prune as unsupported for
    # multi-objective studies).
    callback = [PyTorchLightningPruningCallback(trial, monitor="val_loss")]

    # Hyperparameters
    hidden_size = trial.suggest_int("hidden_size", 2, 16, step=2)
    ff_size = trial.suggest_int("ff_size", 16, 64, step=4)
    num_blocks = trial.suggest_int("num_blocks", 1, 8)
    lr = trial.suggest_float("lr", 5e-5, 5e-4, step=1e-5)
    n_epochs = trial.suggest_int("n_epochs", 3, 12)

    # build and train the TSMixer model with these hyper-parameters:
    model = build_fit_tsmixerx(
        hidden_size=hidden_size,
        ff_size=ff_size,
        num_blocks=num_blocks,
        lr=lr,
        n_epochs=n_epochs,
        callbacks=callback,
    )

    # Evaluate how good it is on the validation set
    preds = model.predict(
        series=train_series,
        n=len(test_series[0]),
        past_covariates=past_cov,
        future_covariates=futr_cov,
        num_samples=200,
    )

    err_metric = np.mean(rmse(test_series, preds, n_jobs=-1, verbose=True))
    ci_error = np.mean([get_ci_cover_err(test_series[i], preds[i]) for i in range(len(preds))])

    # NaN never compares equal to anything (so `x != np.nan` is always True);
    # np.isnan is the reliable check. A NaN score is mapped to the worst value.
    if np.isnan(err_metric):
        err_metric = float("inf")
    if np.isnan(ci_error):
        ci_error = float("inf")

    return err_metric, ci_error



In [None]:
def objective_tft(trial):
    """Optuna objective for the TFT model.

    Samples hyperparameters from ``trial``, trains a model with
    ``build_fit_tft``, forecasts ``len(test_series[0])`` steps past the end of
    the training series and scores the forecast against ``test_series``.

    Returns:
        ``(err_metric, ci_error)``: the mean RMSE over all series and the mean
        of ``get_ci_cover_err`` over all series. A NaN score is returned as
        ``inf``.
    """
    # NOTE(review): confirm that build_fit_tft really forwards `callbacks` to
    # the Lightning trainer, and that pruning is usable in this study (Optuna
    # documents trial.report / should_prune as unsupported for multi-objective
    # studies).
    callback = [PyTorchLightningPruningCallback(trial, monitor="val_loss")]

    # Hyperparameters
    hidden_size = trial.suggest_int("hidden_size", 8, 24, step=2)
    lstm_layers = trial.suggest_int("lstm_layers", 1, 2)
    num_attention_heads = trial.suggest_int("num_attention_heads", 1, 4)
    lr = trial.suggest_float("lr", 1e-4, 5e-4, step=5e-5)
    n_epochs = trial.suggest_int("n_epochs", 3, 6)

    # build and train the TFT model with these hyper-parameters:
    model = build_fit_tft(
        hidden_size=hidden_size,
        lstm_layers=lstm_layers,
        num_attention_heads=num_attention_heads,
        lr=lr,
        n_epochs=n_epochs,
        callbacks=callback,
    )

    # Evaluate how good it is on the validation set
    preds = model.predict(
        series=train_series,
        n=len(test_series[0]),
        past_covariates=past_cov,
        future_covariates=futr_cov,
        num_samples=200,
    )

    err_metric = np.mean(rmse(test_series, preds, n_jobs=-1, verbose=True))
    ci_error = np.mean([get_ci_cover_err(test_series[i], preds[i]) for i in range(len(preds))])

    # NaN never compares equal to anything (so `x != np.nan` is always True);
    # np.isnan is the reliable check. A NaN score is mapped to the worst value.
    if np.isnan(err_metric):
        err_metric = float("inf")
    if np.isnan(ci_error):
        ci_error = float("inf")

    return err_metric, ci_error

In [None]:
def print_callback(study, trial):
    """Optuna callback: log the finished trial and the current Pareto-front leaders.

    Args:
        study: the running study; its ``best_trials`` are summarised.
        trial: the trial that just finished.

    Logs the trial's values/params, then the Pareto-front trial with the lowest
    first objective (RMSE), the lowest second objective (CI error) and the
    lowest sum of both. The summary is skipped while ``best_trials`` is empty.
    """
    log.info(f"Current values: {trial.values}, Current params: \n{log_pretty(trial.params)}")

    # read once: the property is looked up on the study for every access
    pareto_trials = study.best_trials
    if not pareto_trials:
        # min() on an empty list raises ValueError, which would abort optimize()
        log.info("No completed trials yet; skipping best-trial summary")
        return

    best_rmse = min(pareto_trials, key=lambda t: t.values[0])
    best_ci = min(pareto_trials, key=lambda t: t.values[1])
    best_total = min(pareto_trials, key=lambda t: sum(t.values))
    # the first objective returned by the objective functions is RMSE, not SMAPE
    log.info(f"Best RMSE: {best_rmse.values}, Best params: \n{log_pretty(best_rmse.params)}")
    log.info(f"Best CI: {best_ci.values}, Best params: \n{log_pretty(best_ci.params)}")
    log.info(f"Best Total: {best_total.values}, Best params: \n{log_pretty(best_total.params)}")

## Start Experiment

In [None]:
if TEST_MODEL == 'tft':
    study = optuna.create_study(
        directions=["minimize", "minimize"],
        storage="sqlite:///spp_trials.db", 
        study_name="spp_weis_tft",
        load_if_exists=True,
    )
    
    objective_func = objective_tft

elif TEST_MODEL == 'tsmixer':
    study = optuna.create_study(
        directions=["minimize", "minimize"],
        storage="sqlite:///spp_trials.db", 
        study_name="spp_weis_tsmixer",
        load_if_exists=True,
    )
    objective_func = objective_tsmixer

else:
    raise ValueError(f'Unexpected TEST_MODEL: {TEST_MODEL}')
    

In [None]:
# Run the hyperparameter search only when RUN_EXP is set, so the notebook can
# be re-run without starting NUM_TRIALS new trials every time.
if RUN_EXP:
    # or limit time
    # study.optimize(objective_func, timeout=7200, callbacks=[print_callback])
    study.optimize(objective_func, n_trials=NUM_TRIALS, callbacks=[print_callback])
    

In [None]:
target_names = ['RMSE', 'CI_ERROR']

In [None]:
for i, target_name in enumerate(target_names):
    fig = plot_optimization_history(study, target=lambda t: t.values[i], target_name=target_name)
    fig.show()

In [None]:
# for i, target_name in enumerate(target_names):
#     fig = plot_optimization_history(study, target=lambda t: t.values[i], target_name=target_name)
#     fig.show()

In [None]:
for i, target_name in enumerate(target_names):
    fig = plot_contour(study, params=["lr", "n_epochs"], target=lambda t: t.values[i], target_name=target_name)
    fig.show()

In [None]:
plot_param_importances(study)

In [None]:
plot_pareto_front(study, target_names=target_names)

In [None]:
plot_pareto_front(study, target_names=target_names, include_dominated_trials=False)

In [None]:
# TODO: think about how to weight the objective values to pick the best model
best_model = min(study.best_trials, key=lambda t: sum(t.values))
# best_model = min(study.best_trials, key=lambda t: t.values[0])
log.info(f"Best number: {best_model.number}")
log.info(f"Best values: {best_model.values}")
log.info(f"Best params: \n{log_pretty(best_model.params)}")

## Refit and log model with best params

In [None]:
# Refit the selected configuration inside one MLflow run, record back-test
# metrics and parameters, and log the model as a pyfunc with its artifacts.
with mlflow.start_run(experiment_id=exp.experiment_id) as run:
    # fit model with best params from study
    # (the model family is recognised by which hyperparameter names are present)
    if 'lstm_layers' in best_model.params:
        model = build_fit_tft(**best_model.params)
    elif 'ff_size' in best_model.params:
        model = build_fit_tsmixerx(**best_model.params)
    else:
        raise ValueError(f'Unexpected model with params: \n{log_pretty(best_model.params)}')
    
    log.info(f'run.info: \n{run.info}')
    artifact_path = "model_artifacts"
    metrics = {}
    params = model.model_params
    
    # back test on validation data
    acc = model.backtest(
        series=test_series,
        # series=all_series,
        past_covariates=past_cov,
        future_covariates=futr_cov,
        retrain=False,
        forecast_horizon=params['output_chunk_length'],
        stride=25,
        metric=[mae, rmse],
        verbose=False,
    )

    log.info(f'BACKTEST: acc: {acc}')
    log.info(f'BACKTEST: np.mean(acc, axis=0): {np.mean(acc, axis=0)}')
    # average over axis 0 and label the two columns in the order of `metric`
    acc_df = pd.DataFrame(
        np.mean(acc, axis=0).reshape(1,-1),
        columns=['mae', 'rmse']
    )

    # add metrics
    metrics['test_mae'] = acc_df.mae[0]
    metrics['test_rmse'] = acc_df.rmse[0]

    # final training
    # NOTE(review): `final_train_series` is assigned but never used, and the fit
    # below runs on `test_series` only; the back test that follows then scores
    # the model on the same series it was just fitted on. Confirm which data the
    # final fit is meant to use.
    final_train_series = test_series
    log.info('final training')
    model.fit(
            series=test_series,
            past_covariates=past_cov,
            future_covariates=futr_cov,
            verbose=True,
            # epochs=params['n_epochs_final'], # continue training
            )
    
    # final model back test on validation data
    acc = model.backtest(
            series=test_series,
            past_covariates=past_cov,
            future_covariates=futr_cov,
            retrain=False,
            forecast_horizon=params['output_chunk_length'],
            stride=25,
            metric=[mae, rmse],
            verbose=False,
        )

    log.info(f'TEST ACC: acc: {acc}')
    log.info(f'TEST ACC: np.mean(acc, axis=0): {np.mean(acc, axis=0)}')
    acc_df = pd.DataFrame(
        np.mean(acc, axis=0).reshape(1,-1),
        columns=['mae', 'rmse']
    )

    # add and log metrics
    metrics['mae_final'] = acc_df.mae[0]
    metrics['rmse_final'] = acc_df.rmse[0]
    mlflow.log_metrics(metrics)

    # set up path to save model
    model_path = '/'.join([artifact_path, model.MODEL_TYPE])

    # start from an empty artifact directory on every run
    shutil.rmtree(artifact_path, ignore_errors=True)
    os.makedirs(artifact_path)

    # log params
    mlflow.log_params(params)

    # save model files (model, model.ckpt) 
    # and load them to artifacts when logging the model
    model.save(model_path)

    # save MODEL_TYPE to artifacts
    # this will be used to load the model from the artifacts
    model_type_path = '/'.join([artifact_path, 'MODEL_TYPE.pkl'])
    with open(model_type_path, 'wb') as handle:
        pickle.dump(model.MODEL_TYPE, handle)
    
    # map model artifacts in dictionary
    artifacts = {
        'model': model_path,
        'model.ckpt': model_path+'.ckpt',
        'MODEL_TYPE': model_type_path,
    }
    
    # log model
    # https://www.mlflow.org/docs/latest/tutorials-and-examples/tutorial.html#pip-requirements-example
    mlflow.pyfunc.log_model(
        artifact_path='GlobalForecasting',
        code_path=['notebooks/model_training/darts_wrapper.py'],
        signature=darts_signature,
        artifacts=artifacts,
        # model will get loaded from artifacts, we don't need instantiate with one
        python_model=DartsGlobalModel(), 
        pip_requirements=["-r notebooks/model_training/requirements.txt"],
    )


## Get latest run and test predicting

In [None]:
runs = mlflow.search_runs(
    experiment_ids = exp.experiment_id,
    # order_by=['metrics.test_mae']
    order_by=['end_time']
    )

runs.sort_values('end_time', ascending=False, inplace=True)
runs.head()

In [None]:
best_run_id = runs.run_id.iloc[0]
best_run_id

In [None]:
runs['artifact_uri'].iloc[0]

In [None]:
model_path = runs['artifact_uri'].iloc[0] + '/GlobalForecasting'

In [None]:
loaded_model = mlflow.pyfunc.load_model(model_path)

In [None]:
plot_ind = 3
plot_series = all_series[plot_ind]

In [None]:
plot_series.static_covariates.unique_id.LMP

In [None]:
plot_series.plot()

In [None]:
plot_series.start_time()

In [None]:
# forecast origins to plot (each is capped below so enough history remains)
plot_end_times = [
    '2024-07-01T23:00:00',
    '2024-07-15T23:00:00',
    '2024-07-14T23:00:00',
    '2024-07-20T23:00:00',
    '2024-08-17T23:00:00',
    '2024-07-29T23:00:00',
]

# --- inputs that do not depend on the forecast origin are built once ---
plot_node_name = plot_series.static_covariates.unique_id.LMP

# covariates of the node that is actually plotted (index 0 was used before,
# regardless of plot_ind)
future_cov_series = futr_cov[plot_ind]
past_cov_series = past_cov[plot_ind]
future_cov_json = future_cov_series.to_json()
past_cov_json = past_cov_series.to_json()

# assumes plotting.get_plot_df does not modify the frames it is given, so the
# same objects can be reused for every origin -- TODO confirm
plot_cov_df = (
    future_cov_series.pd_dataframe()
    .reset_index()
    .rename(columns={'timestamp_mst':'time', 're_ratio': 'Ratio'})
)

# read the LMP table from the database once instead of once per iteration
lmp_df = lmp.to_pandas().rename(
    columns={
        'LMP': 'LMP_HOURLY',
        'unique_id':'node',
        'timestamp_mst':'time'
    })

# latest origin allowed for this series
latest_end_time = plot_series.end_time() - pd.Timedelta(f'{INPUT_CHUNK_LENGTH+1}h')

for requested_end_time in plot_end_times:
    plot_end_time = min(latest_end_time, pd.Timestamp(requested_end_time))
    log.info(f'plot_end_time: {plot_end_time}')

    # history available to the model at this origin
    node_series = plot_series.drop_after(plot_end_time)
    log.info(f'node_series.end_time(): {node_series.end_time()}')

    # one-row request frame for the logged model
    data = {
        'series': [node_series.to_json()],
        'past_covariates': [past_cov_json],
        'future_covariates': [future_cov_json],
        'n': FORECAST_HORIZON,
        'num_samples': 500
    }
    df = pd.DataFrame(data)

    # Predict on a Pandas DataFrame; the model returns the forecast as JSON.
    pred = loaded_model.predict(df)
    preds = TimeSeries.from_json(pred)

    plot_df = plotting.get_plot_df(
            preds,
            plot_cov_df,
            lmp_df,
            plot_node_name,
        )
    plot_df.rename(columns={'mean':'mean_fcast'}, inplace=True)

    plotting.plotly_forecast(plot_df, plot_node_name, show_fig=True)

In [None]:
df