In [None]:
import os
import io
import tempfile
import shutil
import pickle
import json
import sys
from time import time

from typing import List
import tempfile

import pandas as pd
import boto3
import torch
import warnings
import logging
from dotenv import load_dotenv

from darts.models import (
    TFTModel,
    TiDEModel,
    TSMixerModel,
    NaiveEnsembleModel,
)

# Suppress every Python warning for the rest of the run.
warnings.filterwarnings("ignore")
# Load environment variables via python-dotenv.
load_dotenv()
# INFO-level logging through a named logger used by every cell below.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("model_retrain")
# Wall-clock start; the final cell subtracts this to report total run time.
t0 = time()

In [None]:
# Navigate to project root so src/ imports work regardless of where
# the notebook is launched from (local dev, Databricks, CI, etc.)
#
# Walk upward one directory at a time until the current directory contains
# an entry named `src`. os.chdir('..') is a no-op at the filesystem root, so
# without a stop condition the loop would spin forever when no ancestor has
# a `src` entry; detect the root, restore the starting directory and raise.
_start_dir = os.getcwd()
while 'src' not in os.listdir():
    _current_dir = os.getcwd()
    if os.path.dirname(_current_dir) == _current_dir:
        os.chdir(_start_dir)
        raise FileNotFoundError(
            f"Could not find a 'src' directory in {_start_dir} or any parent directory"
        )
    os.chdir('..')

# Make the modules inside src/ importable by bare name.
src_path = "src"
if os.path.isdir(src_path) and src_path not in sys.path:
    sys.path.insert(0, src_path)

log.info(f"os.getcwd(): {os.getcwd()}")

import data_engineering as de
import parameters
import utils
from modeling import build_fit_tsmixerx, build_fit_tft, build_fit_tide

In [None]:
# Credentials come from one of two places:
#   * Databricks: each value is read from the "aws" secret scope and copied
#     into the process environment; MLflow autologging is switched off.
#   * Anywhere else: the .env file, loaded via python-dotenv below.
if 'dbutils' not in locals():
    print('not on DBX')
else:
    for secret_name in (
        'AWS_ACCESS_KEY_ID',
        'AWS_SECRET_ACCESS_KEY',
        'AWS_S3_BUCKET',
        'AWS_S3_FOLDER',
    ):
        os.environ[secret_name] = dbutils.secrets.get(scope="aws", key=secret_name)
    import mlflow
    mlflow.autolog(disable=True)

from dotenv import load_dotenv
load_dotenv()

In [None]:
# Record the key run parameters in the log, one line per parameter.
for param_name in ("FORECAST_HORIZON", "INPUT_CHUNK_LENGTH", "MODEL_NAME"):
    log.info(f"{param_name}: {getattr(parameters, param_name)}")

In [None]:
# Target S3 bucket and key prefix for this run's artifacts, read from the
# environment.
AWS_S3_BUCKET = os.getenv("AWS_S3_BUCKET")
AWS_S3_FOLDER = os.getenv("AWS_S3_FOLDER")

# Fail fast: os.getenv() returns None for an unset variable, and these values
# are first used to build S3 keys only after all models have been trained.
# An empty string is still accepted (it is a valid key prefix).
_missing_aws_vars = [
    var_name
    for var_name, var_value in (
        ("AWS_S3_BUCKET", AWS_S3_BUCKET),
        ("AWS_S3_FOLDER", AWS_S3_FOLDER),
    )
    if var_value is None
]
if _missing_aws_vars:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_aws_vars)
    )

log.info(f'{AWS_S3_FOLDER = }')

In [None]:
# Single S3 client shared by the upload, download and champion cells below.
s3 = boto3.client("s3")

## Connect to database and prepare data

In [None]:
# Create a DuckDB connection backed by S3 parquet files
# NOTE(review): de.create_database() is not visible from this notebook, so the
# description above is taken on trust — confirm against src/data_engineering.
# The connection is closed at the end of the data-preparation cell.
con = de.create_database()

In [None]:
# LMP data: pull it through the connection, then convert to pandas with the
# column names renamed.
log.info("preparing lmp data")
lmp = de.prep_lmp(con)
lmp_df = lmp.to_pandas().rename(
    columns=dict(LMP="LMP_HOURLY", unique_id="node", timestamp_mst="time")
)

# Covariate data as a pandas frame; .info() prints its column summary.
log.info("preparing covariate data")
all_df_pd = de.all_df_to_pandas(de.prep_all_df(con))
all_df_pd.info()

# Full / train / test / train+test frames. Nothing below reads from the
# connection, so it is closed here.
lmp_all, train_all, test_all, train_test_all = de.get_train_test_all(con)
con.close()

In [None]:
# Turn each DataFrame into Darts TimeSeries objects for training/prediction.
# Conversion order: full history, train+test, train, test.
all_series, train_test_all_series, train_series, test_series = (
    de.get_series(frame)
    for frame in (lmp_all, train_test_all, train_all, test_all)
)

# Covariates derived from the combined covariate frame:
#   futr_cov - future covariates (known ahead of time, e.g. weather/load forecasts)
#   past_cov - past covariates (only known up to the current time)
futr_cov = de.get_futr_cov(all_df_pd)
past_cov = de.get_past_cov(all_df_pd)

## Train models

In [None]:
def _():
    """Fit TSMixer models, one per hyperparameter set.

    Uses the first ``parameters.TOP_N`` entries of
    ``parameters.TSMIXER_PARAMS``. Each model is fit on
    ``train_test_all_series`` with ``test_series`` as the validation series.

    Returns:
        list: the fitted models, or an empty list when
        ``parameters.USE_TSMIXER`` is falsy.
    """
    if not parameters.USE_TSMIXER:
        return []

    fitted = []
    for idx, hparams in enumerate(parameters.TSMIXER_PARAMS[: parameters.TOP_N]):
        # Progress banner between fits.
        print(f"\ni: {idx} \t" + "*" * 25, flush=True)
        fitted.append(
            build_fit_tsmixerx(
                series=train_test_all_series,
                val_series=test_series,
                future_covariates=futr_cov,
                past_covariates=past_cov,
                **hparams,
            )
        )
    return fitted


models_tsmixer = _()

In [None]:
def _():
    """Fit TiDE models, one per hyperparameter set.

    Uses the first ``parameters.TOP_N`` entries of ``parameters.TIDE_PARAMS``.
    Each model is fit on ``train_test_all_series`` with ``test_series`` as the
    validation series.

    Returns:
        list: the fitted models, or an empty list when
        ``parameters.USE_TIDE`` is falsy.
    """
    if not parameters.USE_TIDE:
        return []

    fitted = []
    for idx, hparams in enumerate(parameters.TIDE_PARAMS[: parameters.TOP_N]):
        # Progress banner between fits.
        print(f"\ni: {idx} \t" + "*" * 25, flush=True)
        fitted.append(
            build_fit_tide(
                series=train_test_all_series,
                val_series=test_series,
                future_covariates=futr_cov,
                past_covariates=past_cov,
                **hparams,
            )
        )
    return fitted


models_tide = _()

In [None]:
def _():
    """Fit TFT models, one per hyperparameter set.

    Uses the first ``parameters.TOP_N`` entries of ``parameters.TFT_PARAMS``.
    Each model is fit on ``train_test_all_series`` with ``test_series`` as the
    validation series.

    Returns:
        list: the fitted models, or an empty list when
        ``parameters.USE_TFT`` is falsy.
    """
    if not parameters.USE_TFT:
        return []

    fitted = []
    for idx, hparams in enumerate(parameters.TFT_PARAMS[: parameters.TOP_N]):
        # Progress banner between fits.
        print(f"\ni: {idx} \t" + "*" * 25, flush=True)
        fitted.append(
            build_fit_tft(
                series=train_test_all_series,
                val_series=test_series,
                future_covariates=futr_cov,
                past_covariates=past_cov,
                **hparams,
            )
        )
    return fitted


models_tft = _()

## Save and upload models

In [None]:
# Create a timestamped folder path for this retrain run's artifacts.
# artifact_folder: relative path passed to utils.get_loaded_models()
# artifact_path:   full S3 key prefix for uploading files
#
# pd.Timestamp.now("UTC") replaces pd.Timestamp.utcnow(), which recent pandas
# releases deprecate; both produce a timezone-aware UTC timestamp.
utc_timestamp = pd.Timestamp.now("UTC")
log.info(f'{utc_timestamp = }')

# Second-resolution folder name, with a trailing slash so file names can be
# appended directly.
folder_time = utc_timestamp.strftime('%Y-%m-%d_%H-%M-%S') + '/'
log.info(f'{folder_time = }')

artifact_folder = 'model_retrains/' + folder_time
log.info(f'{artifact_folder = }')

# Plain concatenation: AWS_S3_FOLDER must already end with '/' for the
# result to be a well-formed key prefix.
artifact_path = AWS_S3_FOLDER + artifact_folder
log.info(f'{artifact_path = }')

In [None]:
def _():
    """Upload the training timestamp and every trained model to S3.

    Everything is written under the timestamped ``artifact_path`` prefix.
    Darts' .save() produces two files per model:
      <name>.pt      - pickled model wrapper (config, training state)
      <name>.pt.ckpt - PyTorch Lightning checkpoint (neural network weights)
    Both are required for Darts' .load() to fully restore a trained model.

    Returns:
        list[str]: S3 keys of the timestamp pickle followed by each model's
        ``.pt`` file (TiDE, then TSMixer, then TFT). Checkpoint keys are
        uploaded but not included in the list.
    """
    upload_paths = []

    def model_to_tmp_upload(
        m,
        name: str,
        AWS_S3_BUCKET: str=AWS_S3_BUCKET,
        artifact_path: str=artifact_path,
    ):
        """Save model ``m`` to a temp dir as ``name`` and upload its files.

        Returns the S3 key of the uploaded ``.pt`` wrapper file.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            model_path = os.path.join(tmpdir, name)
            m.save(model_path)

            # Upload the model wrapper file
            upload_path = artifact_path + name
            s3.upload_file(model_path, AWS_S3_BUCKET, upload_path)
            log.info(f'Uploaded: {upload_path}')

            # Upload the checkpoint file with neural network weights
            ckpt_path = model_path + '.ckpt'
            if os.path.exists(ckpt_path):
                ckpt_upload_path = upload_path + '.ckpt'
                s3.upload_file(ckpt_path, AWS_S3_BUCKET, ckpt_upload_path)
                log.info(f'Uploaded: {ckpt_upload_path}')
            else:
                # Surface the gap instead of skipping silently: without the
                # checkpoint the uploaded artifact is incomplete.
                log.warning(f'No checkpoint file produced for {upload_path}')

        return upload_path

    # Upload training timestamp so the Shiny app can display when models were last trained
    buffer = io.BytesIO()
    pickle.dump(utc_timestamp, buffer)
    buffer.seek(0)  # rewind so put_object reads from the start
    upload_path = artifact_path + "TRAIN_TIMESTAMP.pkl"
    s3.put_object(
        Bucket=AWS_S3_BUCKET,
        Key=upload_path,
        Body=buffer,
    )
    log.info(f'Uploaded: {upload_path}')
    upload_paths.append(upload_path)

    for i, m in enumerate(models_tide):
        upload_paths.append(model_to_tmp_upload(m, f"tide_{i}.pt"))

    for i, m in enumerate(models_tsmixer):
        upload_paths.append(model_to_tmp_upload(m, f"tsmixer_{i}.pt"))

    for i, m in enumerate(models_tft):
        upload_paths.append(model_to_tmp_upload(m, f"tft_{i}.pt"))

    return upload_paths

upload_paths = _()

In [None]:
# List all model files just uploaded to S3 for verification and loading
# NOTE(review): utils.get_loaded_models() is not visible here; the loading
# cell below treats its result as a list of S3 key strings — confirm.
loaded_models_for_test = utils.get_loaded_models(artifact_folder)
# Bare expression so the notebook displays the listing.
loaded_models_for_test

## Test loading models from S3 and running inference

### Load models by type

In [None]:
def _():
    """Re-download the uploaded models from S3 and load them.

    Verifies the save/load round-trip works before promoting to champion.
    Each model needs both .pt and .pt.ckpt files in the same directory
    for Darts' .load() to fully restore the trained model.

    Returns:
        list: loaded models in the order TSMixer, TiDE, TFT.
    """

    def get_checkpoints(
        model_filter: str, # 'tsmixer_', 'tide_', or 'tft_'
        loaded_models_for_test: List[str]=loaded_models_for_test,
    ):
        """Filter S3 keys to just the main .pt files (not .ckpt or .pkl).

        Matching is done on the file name (last path component) and on the
        exact ``.pt`` suffix, so a folder prefix that happens to contain
        ``model_filter`` cannot pull in other model types.
        """
        return [
            key for key in loaded_models_for_test
            if key.split('/')[-1].startswith(model_filter) and key.endswith(".pt")
        ]

    def load_model_from_s3(model_class, key):
        """Download a model's .pt + .pt.ckpt files to a temp dir and load."""
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = key.split('/')[-1]
            local_path = os.path.join(tmpdir, filename)

            # Download the model wrapper file
            s3.download_file(Bucket=AWS_S3_BUCKET, Key=key, Filename=local_path)

            # Download the checkpoint file with neural network weights.
            # Kept best-effort, but the underlying error is logged so that a
            # permissions or network failure is not mistaken for a missing file.
            try:
                s3.download_file(
                    Bucket=AWS_S3_BUCKET,
                    Key=key + '.ckpt',
                    Filename=local_path + '.ckpt',
                )
            except Exception as exc:
                log.warning(f"No checkpoint file found for {key}: {exc!r}")

            log.info(f"loading model: {key}")
            # Load onto CPU regardless of the device the model was trained on.
            model = model_class.load(local_path, map_location=torch.device("cpu"))

        return model

    ts_mixer_ckpts = get_checkpoints("tsmixer_")
    ts_mixer_forecasting_models = [load_model_from_s3(TSMixerModel, m) for m in ts_mixer_ckpts]

    tide_ckpts = get_checkpoints("tide_")
    tide_forecasting_models = [load_model_from_s3(TiDEModel, m) for m in tide_ckpts]

    tft_ckpts = get_checkpoints("tft_")
    tft_forecasting_models = [load_model_from_s3(TFTModel, m) for m in tft_ckpts]

    forecasting_models = (
        ts_mixer_forecasting_models
        + tide_forecasting_models
        + tft_forecasting_models
    )

    return forecasting_models

forecasting_models = _()

## Create ensemble model and test predictions

In [None]:
def _():
    """Build the ensemble from the reloaded models and smoke-test a forecast.

    Combines all reloaded models into a NaiveEnsembleModel (simple average
    of predictions) with train_forecasting_models=False since they're
    already trained, then runs a short prediction on one node to verify the
    ensemble works end-to-end before promoting to champion.

    Returns:
        The forecast returned by ``loaded_model.predict`` for the test node.
    """
    # Smoke-test settings, kept together so they are easy to adjust.
    plot_ind = 3   # index into all_series of the node to forecast
    n_steps = 5    # forecast length passed as n
    n_samples = 2  # value passed as num_samples

    log.info("loading model from checkpoints")
    loaded_model = NaiveEnsembleModel(
        forecasting_models=forecasting_models,
        train_forecasting_models=False,
    )

    log.info("test getting predictions")
    plot_series = all_series[plot_ind]

    # Record which node the smoke test runs on.
    plot_node_name = plot_series.static_covariates.unique_id.LMP
    log.info(f"plot_node_name: {plot_node_name}")

    # Cut the series INPUT_CHUNK_LENGTH + 1 hours before its end.
    plot_end_time = plot_series.end_time() - pd.Timedelta(
        f"{parameters.INPUT_CHUNK_LENGTH + 1}h"
    )
    node_series = plot_series.drop_after(plot_end_time)
    log.info(f"plot_end_time: {plot_end_time}")
    log.info(f"node_series.end_time(): {node_series.end_time()}")

    # NOTE(review): covariates are taken at index 0 while the target series is
    # taken at index plot_ind — confirm this pairing is intended.
    future_cov_series = futr_cov[0]
    past_cov_series = past_cov[0]

    pred = loaded_model.predict(
        series=node_series,
        past_covariates=past_cov_series,
        future_covariates=future_cov_series,
        n=n_steps,
        num_samples=n_samples,
    )

    log.info(f"pred: {pred}")

    return pred

pred = _()

In [None]:
# Gate before promotion: raises AssertionError if the smoke test left pred as None.
assert pred is not None

In [None]:
# Display the smoke-test forecast as a DataFrame for a visual sanity check.
# NOTE(review): newer Darts releases appear to rename pd_dataframe() to
# to_dataframe() — confirm against the Darts version this project pins.
pred.pd_dataframe()

In [None]:
def _():
    """Promote this retrain run by writing champion.json to S3.

    The JSON is stored under ``S3_models/`` so the Shiny app knows which
    retrain folder contains the current production models. Nothing is
    uploaded when the smoke-test prediction is ``None``.
    """
    if pred is None:
        log.warning("Prediction failed, not saving json")
        return

    champion_key = AWS_S3_FOLDER + "S3_models/champion.json"
    champion_json = dict(
        champion=folder_time,
        champion_artifact_folder=artifact_folder,
        champion_artifact_path=artifact_path,
    )

    # Serialise to UTF-8 bytes and hand put_object a file-like body.
    payload = json.dumps(champion_json).encode("utf-8")
    s3.put_object(Bucket=AWS_S3_BUCKET, Key=champion_key, Body=io.BytesIO(payload))

    log.info(f"Uploaded champion model json: {champion_key}")
    log.info(f"champion_json: {champion_json}")


_()

In [None]:
# Report total wall-clock time since the first cell set t0.
t1 = time()
elapsed_min = (t1 - t0) / 60
log.info("finished retraining")
log.info(f"total time (min): {elapsed_min:.2f}")