In [None]:
# TODO:
#    1. early stopping is not working in the optuna experiment

In [None]:
import os
import shutil
import pickle
import random
import sys
import numpy as np
import pandas as pd
import duckdb
from typing import List

import boto3

import requests
from io import StringIO

import ibis
import ibis.selectors as s
from ibis import _
ibis.options.interactive = True

from sklearn.preprocessing import RobustScaler

import torch

from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import (
    Scaler,
    MissingValuesFiller,
    Mapper,
    InvertibleMapper,
)
from darts.dataprocessing import Pipeline
from darts.metrics import mape, smape, mae, ope, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, IceCreamHeaterDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.likelihood_models import QuantileRegression, GumbelLikelihood, GaussianLikelihood

from darts import TimeSeries
from darts.utils.timeseries_generation import (
    gaussian_timeseries,
    linear_timeseries,
    sine_timeseries,
)
from darts.models import (
    TFTModel,
    TiDEModel,
    DLinearModel,
    NLinearModel,
    TSMixerModel,
    NaiveEnsembleModel,
    RegressionEnsembleModel,
)


from torchmetrics import (
    SymmetricMeanAbsolutePercentageError, 
    MeanAbsoluteError, 
    MeanSquaredError,
)

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import mlflow

import warnings
# Blanket filter: every warning category is silenced for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the imported libraries.
warnings.filterwarnings("ignore")

from dotenv import load_dotenv
# Load variables from a local .env file into the process environment
# (presumably the credentials needed for the S3 reads below — confirm).
load_dotenv()

# logging
import logging

# define log
# Root logger is configured at INFO; `log` is the logger used throughout this notebook.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

In [None]:
# Raise the rank-zero loggers of both Lightning package layouts to WARNING
# so their INFO messages stay out of the notebook output.
# https://github.com/Lightning-AI/pytorch-lightning/issues/3431
for _rank_zero_logger in (
    "lightning.pytorch.utilities.rank_zero",
    "pytorch_lightning.utilities.rank_zero",
):
    logging.getLogger(_rank_zero_logger).setLevel(logging.WARNING)

In [None]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.visualization import (
    plot_optimization_history,
    plot_contour,
    plot_param_importances,
    plot_pareto_front,
)

In [None]:
# Move the working directory two levels up, where the `src` package imported
# in the next cell is expected to live.
# The starting directory is remembered the first time this cell runs, so
# re-running it lands in the same place instead of climbing two more levels
# on every execution.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.abspath(os.path.join(NOTEBOOK_DIR, '..', '..')))

In [None]:
# custom modules
import src.data_engineering as de
from src import parameters
from src import plotting
from src.modeling import (
    get_ci_err, build_fit_tsmixerx, build_fit_tide, build_fit_tft, log_pretty
)

## will be loaded from root when deployed
from src.darts_wrapper import DartsGlobalModel

In [None]:
# Record the horizon / lookback settings this run uses.
for _param_name in ('FORECAST_HORIZON', 'INPUT_CHUNK_LENGTH'):
    log.info(f'{_param_name}: {getattr(parameters, _param_name)}')

In [None]:
# Select the 'medium' float32 matmul precision setting for this session
# (a speed/precision trade-off; see the torch documentation for the exact semantics).
torch.set_float32_matmul_precision('medium')

In [None]:
# optuna.delete_study(study_name="spp_weis_tide", storage="sqlite:///spp_trials.db")

## Data prep

In [None]:
# client for uploading model weights
# NOTE(review): `s3` is not referenced again in the cells visible in this notebook.
s3 = boto3.client('s3')

In [None]:
# Open a DuckDB connection (no database file given) and register each source
# parquet file on S3 as a table of the same name.
# (local alternative: ibis.duckdb.connect("data/spp.ddb", read_only=True))
con = ibis.duckdb.connect()
for table_name in ('lmp', 'mtrf', 'mtlf'):
    log.info(f'getting {table_name} data from s3')
    con.read_parquet(f's3://spp-weis/data/{table_name}.parquet', table_name)
log.info('finished getting data from s3')

# show which tables are now registered on the connection
con.list_tables()

In [None]:
# Prepared LMP table built from the connection; the bare name displays it.
lmp = de.prep_lmp(con)
lmp

In [None]:
# needed for plotting: pandas copy of the LMP table with the column names
# the plotting helpers are given later in the notebook
plot_column_names = {
    'LMP': 'LMP_HOURLY',
    'unique_id': 'node',
    'timestamp_mst': 'time',
}
lmp_df = lmp.to_pandas().rename(columns=plot_column_names)

In [None]:
# Prepared MTRF table built from the connection; the bare name displays it.
mtrf = de.prep_mtrf(con)
mtrf

In [None]:
# Prepared MTLF table built from the connection; the bare name displays it.
mtlf = de.prep_mtlf(con)
mtlf

In [None]:
# Combined table produced by de.prep_all_df from the connection; displayed as the cell output.
all_df = de.prep_all_df(con)
all_df

In [None]:
# Convert the combined table to pandas. The `all_df` object built in the
# previous cell is reused instead of calling de.prep_all_df(con) a second time.
all_df_pd = de.all_df_to_pandas(all_df)
all_df_pd

In [None]:
# Summary of the combined pandas frame (columns, dtypes, non-null counts).
all_df_pd.info()

## Prep model training data

In [None]:
# Unpack the four objects returned by de.get_train_test_all; judging by the
# names they are the full, train, test and train+test LMP data.
lmp_all, train_all, test_all, train_test_all = de.get_train_test_all(con)

In [None]:
# Display the full LMP data returned by the previous cell.
lmp_all

In [None]:
# Series collection built from lmp_all; plot its first element as a quick visual check.
all_series = de.get_series(lmp_all)
all_series[0].plot()

In [None]:
# Series collection built from train_test_all; plot its first element as a quick visual check.
train_test_all_series = de.get_series(train_test_all)
train_test_all_series[0].plot()

In [None]:
# Series collection built from train_all; plot its first element as a quick visual check.
train_series = de.get_series(train_all)
train_series[0].plot()

In [None]:
# Series collection built from test_all; plot its first element as a quick visual check.
test_series = de.get_series(test_all)
test_series[0].plot()

In [None]:
# Future covariates derived from the combined pandas frame; plot the first element.
futr_cov = de.get_futr_cov(all_df_pd)
futr_cov[0].plot()

In [None]:
# Past covariates derived from the combined pandas frame; plot the first element.
past_cov = de.get_past_cov(all_df_pd)
past_cov[0].plot()

In [None]:
# Close the database connection now that the prepared objects have been built from it.
# Cells above that use `con` need the connection cell re-run before they are executed again.
con.disconnect()

## Pretrain models with the best params

In [None]:
# Fit one TSMixer model per stored parameter set (first TOP_N entries),
# skipped entirely when the model family is switched off in `parameters`.
models_tsmixer = []
if parameters.USE_TSMIXER:
    # Data arguments shared by every candidate configuration.
    # NOTE(review): training uses train_test_all_series while validation uses
    # test_series; if the former contains the test window, validation loss is
    # computed on data the model was fitted on — confirm this is intended.
    tsmixer_fit_data = dict(
        series=train_test_all_series,
        val_series=test_series,
        future_covariates=futr_cov,
        past_covariates=past_cov,
    )
    top_tsmixer_params = parameters.TSMIXER_PARAMS[:parameters.TOP_N]
    for rank, hyper_params in enumerate(top_tsmixer_params):
        print(f'\ni: {rank} \t' + '*' * 25, flush=True)
        models_tsmixer.append(
            build_fit_tsmixerx(**tsmixer_fit_data, **hyper_params)
        )

In [None]:
# Fit one TiDE model per stored parameter set (first TOP_N entries),
# skipped entirely when the model family is switched off in `parameters`.
models_tide = []
if parameters.USE_TIDE:
    # Data arguments shared by every candidate configuration.
    # NOTE(review): training uses train_test_all_series while validation uses
    # test_series; if the former contains the test window, validation loss is
    # computed on data the model was fitted on — confirm this is intended.
    tide_fit_data = dict(
        series=train_test_all_series,
        val_series=test_series,
        future_covariates=futr_cov,
        past_covariates=past_cov,
    )
    top_tide_params = parameters.TIDE_PARAMS[:parameters.TOP_N]
    for rank, hyper_params in enumerate(top_tide_params):
        print(f'\ni: {rank} \t' + '*' * 25, flush=True)
        models_tide.append(
            build_fit_tide(**tide_fit_data, **hyper_params)
        )

In [None]:
# Fit one TFT model per stored parameter set (first TOP_N entries),
# skipped entirely when the model family is switched off in `parameters`.
models_tft = []
if parameters.USE_TFT:
    # Data arguments shared by every candidate configuration.
    # NOTE(review): training uses train_test_all_series while validation uses
    # test_series; if the former contains the test window, validation loss is
    # computed on data the model was fitted on — confirm this is intended.
    tft_fit_data = dict(
        series=train_test_all_series,
        val_series=test_series,
        future_covariates=futr_cov,
        past_covariates=past_cov,
    )
    top_tft_params = parameters.TFT_PARAMS[:parameters.TOP_N]
    for rank, hyper_params in enumerate(top_tft_params):
        print(f'\ni: {rank} \t' + '*' * 25, flush=True)
        models_tft.append(
            build_fit_tft(**tft_fit_data, **hyper_params)
        )

## Create ensemble model

In [None]:
# Flatten the per-family model lists into one list, keeping the
# TSMixer -> TiDE -> TFT order.
forecasting_models = [*models_tsmixer, *models_tide, *models_tft]

In [None]:
# Combine the fitted models into a single ensemble.
# train_forecasting_models=False: the member models were already fitted by the
# build_fit_* calls in the cells above, so the ensemble is asked not to train them.
# NOTE(review): the variable is named `loaded_model` although nothing is loaded here;
# the name matches the commented-out MLflow loading code at the end of the notebook.
loaded_model = NaiveEnsembleModel(
    forecasting_models=forecasting_models, 
    train_forecasting_models=False
)

## Plot test predictions

In [None]:
# Index of the series to inspect; the same index is used below to pick the
# matching entry of test_series.
plot_ind = 3
plot_series = all_series[plot_ind]

In [None]:
# Identifier of the selected series, read from its static covariates (displayed as the cell output).
plot_series.static_covariates.unique_id.LMP

In [None]:
# Plot the selected series over its whole time range.
plot_series.plot()

In [None]:
# Ten daily cut-off timestamps ending at the last timestamp of the selected
# node's test series; the plotting loop below truncates the series at each one.
plot_end_times = pd.date_range(
    end=test_series[plot_ind].end_time(),
    periods=10,
    freq='D',  # canonical daily alias; the lowercase spelling is deprecated in recent pandas
)

plot_end_times

In [None]:
# Rolling visual check of the ensemble: for each cut-off time the selected
# series is truncated, a FORECAST_HORIZON-step probabilistic forecast is drawn,
# and the result is rendered with plotting.plotly_forecast.
#
# The JSON request frame for the MLflow pyfunc path was removed from this loop:
# that path is disabled and the frame was never read, yet it serialised three
# series to JSON on every iteration.

# number of samples drawn per probabilistic forecast
NUM_SAMPLES = 500

# --- loop invariants, resolved once ---
plot_node_name = plot_series.static_covariates.unique_id.LMP

# NOTE(review): covariates are read from index 0 while the target series is
# all_series[plot_ind]; presumably the covariates are shared by every node —
# confirm, otherwise these should be indexed with plot_ind.
future_cov_series = futr_cov[0]
past_cov_series = past_cov[0]

# Future covariates as a flat frame with the column names handed to the plotting helper.
# NOTE(review): built once and reused across iterations, as lmp_df already is;
# assumes plotting.get_plot_df does not modify its inputs — confirm.
plot_cov_df = (
    future_cov_series.pd_dataframe()
    .reset_index()
    .rename(columns={'timestamp_mst': 'time', 're_ratio': 'Ratio'})
)

for plot_end_time in plot_end_times:
    log.info(f'plot_end_time: {plot_end_time}')

    # truncate the selected series at the cut-off time
    node_series = plot_series.drop_after(plot_end_time)
    log.info(f'node_series.end_time(): {node_series.end_time()}')

    # forecast with the darts ensemble directly
    preds = loaded_model.predict(
        series=node_series,
        past_covariates=past_cov_series,
        future_covariates=future_cov_series,
        n=parameters.FORECAST_HORIZON,
        num_samples=NUM_SAMPLES,
    )

    # get_plot_df assembles the plotting frame from the forecast, the
    # covariates and the LMP frame; only the 'mean' column is renamed here
    plot_df = plotting.get_plot_df(
        preds,
        plot_cov_df,
        lmp_df,
        plot_node_name,
    ).rename(columns={'mean': 'mean_fcast'})

    plotting.plotly_forecast(plot_df, plot_node_name, show_fig=True)

In [None]:
# df

## Deprecated MLflow code

In [None]:
## MLFlow set up
# # mlflow.set_tracking_uri("sqlite:///mlruns.db")
# log.info(f'mlflow.get_tracking_uri(): {mlflow.get_tracking_uri()}')
# exp_name = 'spp_weis'

# if mlflow.get_experiment_by_name(exp_name) is None:
#     exp = mlflow.create_experiment(exp_name)
    
# exp = mlflow.get_experiment_by_name(exp_name)
# exp

In [None]:
## Get model signature
# node_series = train_series[0]
# future_cov_series = futr_cov[0]
# past_cov_series = past_cov[0]

# data = {
#     'series': [node_series.to_json()],
#     'past_covariates': [past_cov_series.to_json()],
#     'future_covariates': [future_cov_series.to_json()],
#     'n': parameters.FORECAST_HORIZON,
#     'num_samples': 200
# }

# df = pd.DataFrame(data)

# ouput_example = 'the endpoint return json as a string'

# from mlflow.models import infer_signature
# darts_signature = infer_signature(df, ouput_example)
# darts_signature

In [None]:
## Get latest run
# runs = mlflow.search_runs(
#     experiment_ids = exp.experiment_id,
#     # order_by=['metrics.test_mae']
#     order_by=['end_time']
#     )

# runs.sort_values('end_time', ascending=False, inplace=True)
# runs.head()

# best_run_id = runs.run_id.iloc[0]
# best_run_id

# runs['artifact_uri'].iloc[0]

# model_path = runs['artifact_uri'].iloc[0] + '/GlobalForecasting'

# loaded_model = mlflow.pyfunc.load_model(model_path)