# Forecasting Baselines - Running the experiments

This script takes very long to run (possibly days if not parallelized).
This is mostly due to the local and univariate models, which need to be trained for each time series
individually. The global models are trained on all time series at once and are therefore much, much faster to evaluate.
Note that, to speed up the evaluation of the local models, they can be restricted to the first `evaluate_first_n_companies` companies (see the commented-out `evaluate_first_n_companies` argument of `ModelEvaluator()` below).

Also, this notebook will prompt you for a *Weights and Biases* API key. You can get one for free at https://wandb.ai.
Alternatively, you can run the notebook without Weights and Biases by commenting out `wandb.login()` and by passing `use_logger=False` to `ModelEvaluator()`.


## Setup

In [None]:
import os

# Silence wandb's console output; this is set before `wandb` is imported in the next cell.
os.environ.update(WANDB_SILENT="true")

In [None]:
from pathlib import Path
import logging
import pickle
from argparse import ArgumentParser
import warnings

import torch
from torch import optim
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping
from sortedcontainers import SortedDict
from darts import TimeSeries
from darts.models.forecasting.forecasting_model import ForecastingModel
import wandb
from rich.progress import track

from proprietary_data import CompanyFundamentalsKind, ALL_FEATURE_NAMES, KEY_FEATURE_NAMES
from proprietary_data.darts import (
    load_company_fundamentals_as_darts_time_series,
    TimeSeriesContainer,
    generate_splits,
)
from forecasting_cfs.eval_model import ModelEvaluator, ForecastingResult

In [None]:
# Make sure we are logged in to wandb.
# This is an explicit check rather than `assert wandb.login()`: asserts are stripped under
# `python -O`, which would silently skip the login call itself.
if not wandb.login():
    raise RuntimeError(
        "Could not log in to Weights and Biases (wandb.login() returned False)"
    )

In [None]:
# Darts logs a lot, so switch off all logging up to and including CRITICAL
# (the level `logging.disable()` uses when called without an argument).
logging.disable(logging.CRITICAL)

In [None]:
def is_interactive():
    """Return True when this code is not running from a file-backed ``__main__`` module.

    The check relies on ``__main__`` lacking a ``__file__`` attribute, which is
    typically the case in notebooks and interactive interpreters.
    """
    import __main__

    has_source_file = hasattr(__main__, "__file__")
    return not has_source_file


# Guard for exported scripts: the notebook cells below are not wrapped in any block, so
# running this file as a plain script is refused until the code has been re-indented
# (presumably under this guard) after export.
if __name__ == "__main__" and not is_interactive():
    raise RuntimeError(
        "You need to indent everything below this by one after exporting to a standalone file"
    )


TARGET_ALL_FEATURES = False
USE_STATIC_COVARIATES = True
USE_REVIN = True
NORMALIZE_LOCAL_MODELS = True
UNCERTAINTY = True

# We do not provide a seasonality guess since even for quarterly data, the seasonality is not always 4
# as automatic seasonality detection indicates

HORIZON_LOOKBACK = 12  # 16 would also be fine
HORIZON_FORECAST = 4
STRIDE = 1
WINDOW_MODE = "expanding"

only_from_to: None | tuple[int, int] = None


if not is_interactive():
    # Only needed when running as a standalone script.
    from argparse import BooleanOptionalAction

    parser = ArgumentParser()

    # Each boolean switch gets a `--flag` / `--no-flag` pair, defaulting to the
    # module-level setting defined above.
    parser.add_argument(
        "--use_all_features", default=TARGET_ALL_FEATURES, action=BooleanOptionalAction
    )
    parser.add_argument(
        "--use_static_covariates",
        default=USE_STATIC_COVARIATES,
        action=BooleanOptionalAction,
    )
    parser.add_argument("--use_revin", default=USE_REVIN, action=BooleanOptionalAction)
    parser.add_argument(
        "--use_uncertainty", default=UNCERTAINTY, action=BooleanOptionalAction
    )

    parser.add_argument("--lookback", type=int, default=HORIZON_LOOKBACK)
    parser.add_argument("--forecast", type=int, default=HORIZON_FORECAST)
    parser.add_argument("--stride", type=int, default=STRIDE)
    parser.add_argument(
        "--window", default=WINDOW_MODE, choices=["sliding", "expanding"]
    )

    parser.add_argument(
        "--limit",
        type=str,
        default="",
        help="Limit the range in index steps. Format: `from,to`, where the latter is exclusive.",
    )

    args = parser.parse_args()

    TARGET_ALL_FEATURES = args.use_all_features
    USE_STATIC_COVARIATES = args.use_static_covariates
    USE_REVIN = args.use_revin
    UNCERTAINTY = args.use_uncertainty

    HORIZON_LOOKBACK = args.lookback
    HORIZON_FORECAST = args.forecast
    STRIDE = args.stride
    WINDOW_MODE = args.window

    if args.limit:
        # Validate explicitly (not via `assert`, which is stripped under `python -O`) and
        # report malformed input as a normal CLI usage error.
        try:
            limit_from, limit_to = map(int, args.limit.split(","))
        except ValueError:
            parser.error(f"--limit must have the format `from,to`, got {args.limit!r}")
        if limit_from >= limit_to:
            parser.error(f"--limit requires from < to, got {args.limit!r}")
        only_from_to = (limit_from, limit_to)

In [None]:
# The run name determines the output directory. Every setting that changes the experiment
# must be part of it; otherwise runs with different settings share a directory and skip or
# overwrite each other's folds. The two suffixes below are only appended for non-default
# values, so directories of runs with the default settings keep their existing names.
run_name = f"DEBUG-statics_{USE_STATIC_COVARIATES}-revin_{USE_REVIN}-stride{STRIDE}_lb{HORIZON_LOOKBACK}-{WINDOW_MODE}-uq_{UNCERTAINTY}"
if TARGET_ALL_FEATURES:
    run_name += "-all_features"
if HORIZON_FORECAST != 4:  # 4 is the default forecast horizon
    run_name += f"-fc{HORIZON_FORECAST}"
OUT_PATH = Path("forecast_baselines") / run_name

 ## Data Loading

In [None]:
# Forecast targets; every remaining feature is used as a (past) covariate.
if TARGET_ALL_FEATURES:
    TARGETS = ALL_FEATURE_NAMES
else:
    TARGETS = KEY_FEATURE_NAMES
COVARIATES = [name for name in ALL_FEATURE_NAMES if name not in TARGETS]

# Load the normalized fundamentals as darts series and one-hot encode the static metadata.
darts_ts: TimeSeriesContainer = load_company_fundamentals_as_darts_time_series(
    kind=CompanyFundamentalsKind.Normalized,
    subset=False,
    # only use complete sequences for simplicity when comparing to other models
    min_length="max",
    target_columns=TARGETS,
    covariate_columns=COVARIATES,
    include_static_metadata=USE_STATIC_COVARIATES,
    static_metadata_columns=["GICS_sector"],
    n_jobs=4,  # Increasing this won't help much
).one_hot_encode_statics()

We evaluate by splitting the 40 quarterly datapoints into several folds.
We do this to evaluate the models on different time periods, which is important for forecasting, since the dynamics can change drastically over time.
In extreme cases, this can be due to financial crises, conflicts, etc., but it can also stem from more subtle effects such as the interplay of debt vs. equity financing and reporting behaviour.

In [None]:
# Sanity check: besides the lookback window we need twice the HORIZON_FORECAST
# (training + testing). Raise explicitly, since `assert` is stripped under `python -O`.
if HORIZON_LOOKBACK + 2 * HORIZON_FORECAST > darts_ts.targets[0].n_timesteps:
    raise ValueError(
        f"Series are too short: {darts_ts.targets[0].n_timesteps} time steps < "
        f"lookback ({HORIZON_LOOKBACK}) + 2 * forecast ({HORIZON_FORECAST})"
    )

The `generate_splits` call below will generate many slices of the data (depending on the number of time steps `darts_ts.targets[0].n_timesteps` and the `stride`), which we will use for cross-validation.

## Enumerate Experiments

In [None]:
from darts.models import (
    NBEATSModel,
    ARIMA,
    VARIMA,
    StatsForecastAutoARIMA,
    StatsForecastAutoTheta,
    Prophet,
    RandomForest,
    LinearRegressionModel,
    RNNModel,
    TCNModel,
    TransformerModel,
    TFTModel,
    DLinearModel,
    NLinearModel,
    NaiveMean,
    NaiveMovingAverage,
    NHiTSModel,
    TiDEModel,
    BlockRNNModel,
)
from darts.models.forecasting.forecasting_model import GlobalForecastingModel
from darts.utils.likelihood_models import QuantileRegression
from forecasting_cfs.chronos import ChronosDartsWrapper
from forecasting_cfs.xlstm_mixer import xLSTMMixer

# `likelihood` is passed to the regression models, `likelihood_model()` builds the
# likelihood object for the torch models (None for point forecasts).
if UNCERTAINTY:
    likelihood = "quantile"
    likelihood_model = QuantileRegression
else:
    likelihood = None

    def likelihood_model():
        return None


# See https://arxiv.org/pdf/1803.09820.pdf
# See https://unit8co.github.io/darts/examples/18-TiDE-examples.html
def my_early_stopper():
    """Create a new early-stopping callback.

    Training stops once the validation loss has not decreased for 3 consecutive checks.
    A fresh instance is returned on every call so models do not share callback state.
    """
    stopping_settings = dict(monitor="val_loss", mode="min", patience=3)
    return EarlyStopping(**stopping_settings)


def common_torch_wo_dropout():
    """Return the keyword arguments shared by all torch-based darts models (no dropout).

    Covers the Lightning trainer settings (with a fresh early-stopping callback),
    an AdamW optimizer, the batch size and the maximum number of epochs.
    """
    trainer_kwargs = {
        "accelerator": "auto",
        "devices": "auto",
        "gradient_clip_val": 1.0,
        "callbacks": [my_early_stopper()],
    }
    return {
        "pl_trainer_kwargs": trainer_kwargs,
        "optimizer_cls": optim.AdamW,
        "optimizer_kwargs": {"lr": 0.0001, "weight_decay": 0.01},
        "batch_size": 64,
        "n_epochs": 100,
    }


def common_torch():
    """Return the shared torch-model keyword arguments plus a dropout rate of 0.1."""
    torch_kwargs = common_torch_wo_dropout()
    torch_kwargs["dropout"] = 0.1
    return torch_kwargs


def make_models() -> dict[str, ForecastingModel]:
    """Instantiate a fresh, untrained set of all baseline models.

    The model settings follow the module-level configuration (lookback/forecast horizons,
    static covariates, RevIN, uncertainty).

    Returns:
        Mapping from display name to model. If ``UNCERTAINTY`` is set, only models whose
        class supports probabilistic forecasts (per the list below) are kept.
    """
    # Past covariates only exist if not every feature is a forecasting target
    past_covariate_lags = HORIZON_LOOKBACK if not TARGET_ALL_FEATURES else None

    models: dict[str, ForecastingModel] = {
        "Mean": NaiveMean(),
        "ARMean(1)": NaiveMovingAverage(input_chunk_length=1),
        "ARMean(4)": NaiveMovingAverage(input_chunk_length=4),
        "ARMA(1,1)": ARIMA(p=1, d=0, q=1, trend="ct"),
        "ARMA(4,4)": ARIMA(p=4, d=0, q=4, trend="ct"),
        "ARIMA(4,1,4)": ARIMA(p=4, d=1, q=4, trend="t"),
        "VARIMA(4,0,4)": VARIMA(p=4, d=0, q=4, trend="ct"),
        "VARIMA(4,1,4)": VARIMA(p=4, d=1, q=4),
        "AutoARIMA": StatsForecastAutoARIMA(),
        "AutoTheta": StatsForecastAutoTheta(),
        "Prophet": Prophet(),
        # ############################################
        "Linear Reg.": LinearRegressionModel(
            lags=HORIZON_LOOKBACK,
            output_chunk_length=HORIZON_FORECAST,
            lags_past_covariates=past_covariate_lags,
            multi_models=True,
            use_static_covariates=USE_STATIC_COVARIATES,
            likelihood=likelihood,
        ),
        "Random Forest": RandomForest(
            lags=HORIZON_LOOKBACK,
            output_chunk_length=HORIZON_FORECAST,
            # Same as for "Linear Reg.": no past covariate lags without past covariates
            lags_past_covariates=past_covariate_lags,
            use_static_covariates=USE_STATIC_COVARIATES,
        ),
        # ############################################
        "DLinear": DLinearModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            kernel_size=10,  # for calculating the moving average to remove trend
            use_static_covariates=USE_STATIC_COVARIATES,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch_wo_dropout(),
        ),
        "NLinear": NLinearModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            normalize=not UNCERTAINTY,
            use_static_covariates=USE_STATIC_COVARIATES,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch_wo_dropout(),
        ),
        "RNN (LSTM)": RNNModel(
            input_chunk_length=HORIZON_LOOKBACK,
            model="LSTM",
            # https://unit8co.github.io/darts/userguide/torch_forecasting_models.html#required-target-time-spans-for-training-validation-and-prediction
            training_length=HORIZON_LOOKBACK + 1,  # This is an autoregressive model
            hidden_dim=64,
            n_rnn_layers=3,
            # Should not be used with use_reversible_instance_norm=True
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "RNN (GRU)": RNNModel(
            # See "RNN (LSTM)" above for comments
            input_chunk_length=HORIZON_LOOKBACK,
            model="GRU",
            training_length=HORIZON_LOOKBACK + 1,
            hidden_dim=64,
            n_rnn_layers=3,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "Block RNN (LSTM)": BlockRNNModel(
            # See "RNN (LSTM)" above for comments
            input_chunk_length=HORIZON_LOOKBACK,
            output_chunk_length=HORIZON_FORECAST,
            model="LSTM",
            hidden_dim=128,
            n_rnn_layers=3,
            # Follow the global switch so the `revin_...` part of the run name is accurate
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "Block RNN (GRU)": BlockRNNModel(
            # See "RNN (LSTM)" above for comments
            input_chunk_length=HORIZON_LOOKBACK,
            output_chunk_length=HORIZON_FORECAST,
            model="GRU",
            hidden_dim=128,
            n_rnn_layers=3,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "TCN": TCNModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            num_filters=16,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "Transformer": TransformerModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            d_model=120,
            dim_feedforward=512,
            num_encoder_layers=4,
            num_decoder_layers=4,
            nhead=6,
            activation="gelu",
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "TFT": TFTModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            hidden_size=36,
            num_attention_heads=6,
            full_attention=True,
            add_relative_index=True,  # For the forecast without covariates
            loss_fn=None if UNCERTAINTY else torch.nn.MSELoss(),
            use_static_covariates=USE_STATIC_COVARIATES,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "N-BEATS": NBEATSModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            layer_widths=512,
            num_layers=6,
            num_blocks=1,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **(common_torch() | dict(batch_size=256)),
        ),
        "N-HiTS": NHiTSModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "TiDE": TiDEModel(
            HORIZON_LOOKBACK,
            HORIZON_FORECAST,
            num_encoder_layers=2,
            num_decoder_layers=2,
            use_static_covariates=USE_STATIC_COVARIATES,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch(),
        ),
        "xLSTM-Mixer": xLSTMMixer(
            input_chunk_length=HORIZON_LOOKBACK,
            output_chunk_length=HORIZON_FORECAST,
            use_reversible_instance_norm=USE_REVIN,
            likelihood=likelihood_model(),
            **common_torch_wo_dropout(),
        ),
        # ############################################
        # See: https://huggingface.co/autogluon/chronos-bolt-base
        "Chronos-Bolt-Small": ChronosDartsWrapper(
            model_name="amazon/chronos-bolt-small"
        ),
        "Chronos-Bolt-Base": ChronosDartsWrapper(model_name="amazon/chronos-bolt-base"),
    }
    if not UNCERTAINTY:
        return models

    # Model classes supporting probabilistic forecasts, as of darts version 0.27.2
    supports_uncertainty = (
        ARIMA,
        VARIMA,
        StatsForecastAutoARIMA,
        # ExponentialSmoothing,
        # StatsforecastAutoETS,
        # BATS,
        # TBATS,
        StatsForecastAutoTheta,
        Prophet,
        # KalmanForecaster,
        LinearRegressionModel,
        # LightGBMModel,
        # XGBModel,
        # CatBoostModel,
        GlobalForecastingModel,  # All of them
    )
    return {
        name: model
        for name, model in models.items()
        # RandomForest is a GlobalForecastingModel, so it has to be excluded explicitly
        if isinstance(model, supports_uncertainty)
        and not isinstance(model, RandomForest)
    }


In [None]:
# Number of models evaluated per fold with the current settings (with UNCERTAINTY,
# models without probabilistic support are filtered out by make_models()).
len(make_models())

## Run Experiments

### Fit models & Eval

In [None]:
# Train and evaluate all models on every cross-validation fold and persist the results.
for cross_validation_fold_index, data_entry in enumerate(
    generate_splits(
        darts_ts,
        horizon_lookback=HORIZON_LOOKBACK,
        horizon_forecast=HORIZON_FORECAST,
        stride=STRIDE,
        fraction_validation=0.1,  # We don't have much data
        window_mode=WINDOW_MODE,
    )
):
    # Optionally only run a sub-range of folds (`--limit from,to`, `to` is exclusive)
    if only_from_to is not None and cross_validation_fold_index not in range(
        *only_from_to
    ):
        continue

    train_ts, val_ts, test_ts = data_entry

    cv_path = OUT_PATH / str(cross_validation_fold_index)

    result_data_path = cv_path / "result_data"
    result_data_path.mkdir(exist_ok=True, parents=True)

    # `example_predictions.pkl` is written last for each fold and thus marks it as complete;
    # completed folds are skipped so that an interrupted run can be resumed.
    example_predictions_path = result_data_path / "example_predictions.pkl"
    if example_predictions_path.exists():
        print(f"Skipping {cross_validation_fold_index} as it already exists")
        continue

    # Fresh, untrained model instances for every fold
    models = make_models()

    with warnings.catch_warnings():
        # For ARIMA-family models, we get a lot of warnings about convergence issues; but also for other models
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            module="statsmodels.+",
            append=True,
        )
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            module="torch.nn.modules.transformer",
            append=True,
        )

        with ModelEvaluator(
            train_ts=train_ts,
            val_ts=val_ts,
            test_ts=test_ts,
            input_chunk_length=HORIZON_LOOKBACK,
            output_chunk_length=HORIZON_FORECAST,
            checkpoint_path=cv_path / "checkpoints",
            num_samples=100 if UNCERTAINTY else 1,
            # Group the W&B runs by configuration (previously a fixed placeholder string)
            group_name=run_name,
            job_type=f"fold-{cross_validation_fold_index}",
            # evaluate_first_n_companies=20,
            # return_first_n_companies=20,
            num_processes_local_models=8,
        ) as model_evaluator:
            results = list(model_evaluator.eval_in_parallel(models, num_workers=8))

            # Note: Errors such as "ValueError: Model `StatsForecastAutoARIMA` only supports univariate TimeSeries"
            # come from the subprocesses where logging is not disabled. This is just darts being noisy.

        print(
            f"{len(results)} done out of {len(models)} models, the remaining {len(models) - len(results)} had errors"
        )

        if not results:
            # Nothing to save (building the metrics table would fail on the missing "model"
            # column). The completion marker is not written, so the fold is retried next time.
            print(
                f"No results for cross-validation fold #{cross_validation_fold_index}, nothing was saved"
            )
            print("=" * 60)
            continue

        metrics_for_pandas = [
            (result.metrics | dict(model=result.name)) for result in results
        ]
        df = pd.DataFrame(metrics_for_pandas).set_index("model")
        df.sort_index(inplace=True)
        df = df.melt(var_name="metric", ignore_index=False)
        df.to_pickle(result_data_path / "results.pkl")

        example_predictions_slim: dict[str, dict[int, TimeSeries]] = {
            result.name: {
                fc.meta_data["companyid"]: fc.ts_forecast
                for fc in result.test_forecasts
            }
            for result in results
        }
        with open(result_data_path / "example_predictions_slim.pkl", "wb") as f:
            pickle.dump(example_predictions_slim, f, pickle.HIGHEST_PROTOCOL)

        # Written last: its existence marks this fold as complete (see the check above)
        example_predictions: dict[str, list[ForecastingResult]] = dict(
            SortedDict({result.name: result.test_forecasts for result in results})
        )
        with open(example_predictions_path, "wb") as f:
            pickle.dump(example_predictions, f, pickle.HIGHEST_PROTOCOL)

        print(f"Done with cross-validation fold #{cross_validation_fold_index}")
        print("=" * 60)

## Explainability

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from darts.explainability import TFTExplainer

from proprietary_data import companyid_to_name

# Global plot look for the explainability figures: white grid, colorblind-safe colors.
sns.set_style(style="whitegrid")
sns.set_palette(palette="colorblind")

In [None]:
# Lookup from company id to company name, using the same selection as the loaded data
# (no subset, only complete sequences).
id_to_name = companyid_to_name(min_length="max", subset=False)

In [None]:
# Load the TFT checkpoint of cross-validation fold 0 from a previously completed run.
# NOTE(review): this directory name is hard-coded and lacks the "DEBUG-" prefix that
# `run_name` adds above, so it does not point at the results written by the fold loop
# in this notebook. Confirm that this is the intended run.
explainability_path = (
    Path("forecast_baselines")
    / "statics_True-revin_True-stride1_lb12-expanding-uq_True"
    / "0"
)
model = TFTModel.load(str(explainability_path / "checkpoints" / "TFT.ckpt"))
model

In [None]:
# First split yielded by `generate_splits` (fold 0), using the settings encoded in the
# checkpoint directory name (lookback 12, forecast 4, expanding window).
# NOTE(review): the training loop also passed `stride=STRIDE` and `fraction_validation=0.1`,
# whereas here the `generate_splits` defaults apply. Confirm that fold 0 is identical.
split_iterator = iter(
    generate_splits(
        darts_ts,
        window_mode="expanding",
        horizon_lookback=12,
        horizon_forecast=4,
    )
)
train_ts, val_ts, test_ts = next(split_iterator)

In [None]:
# Example company (index into the split's series) used for the explainability plots below
company_index = 520
company_name = id_to_name[train_ts.meta_data[company_index]["companyid"]]
test_start = test_ts.targets[company_index].start_time()
company_name, test_start

In [None]:
# Explain the TFT for the example company, using its training targets and past covariates
# as background data. Foreground series/covariates could be passed to `explain()`; without
# them the background series is explained.
background_kwargs = dict(
    background_series=train_ts.targets[company_index],
    background_past_covariates=train_ts.covariates[company_index],
)
explainer = TFTExplainer(model, **background_kwargs)
explainability_result = explainer.explain()
explainability_result

In [None]:
# Plot the TFT variable-selection weights for the explained company.
explainer.plot_variable_selection(explainability_result)

In [None]:
# Overlay the company's target series with the TFT attention averaged over its components
# (axis=1); presumably one component per forecast step. TODO confirm.
attention = explainability_result.get_attention().mean(axis=1)

company_targets = train_ts.targets[company_index]
shared_time_index = company_targets.time_index.intersection(attention.time_index)

company_targets[shared_time_index].plot()
attention.plot(label="mean_attention", max_nr_components=12)

pass  # keeps the cell from echoing the return value of the last plot call

In [None]:
# Collect for all splits

# Explain the same example company with the TFT of every cross-validation fold and average
# the encoder and static-covariate importances over all folds.
results = {"encoder": [], "static": []}

all_splits = generate_splits(
    darts_ts, horizon_lookback=12, horizon_forecast=4, window_mode="expanding"
)
# NOTE(review): `total=40` (the number of quarterly data points) only affects the progress
# display and need not equal the number of splits.
for split_id, (train_ts, val_ts, test_ts) in enumerate(track(all_splits, total=40)):
    checkpoint_file = (
        explainability_path.parent / str(split_id) / "checkpoints" / "TFT.ckpt"
    )
    model = TFTModel.load(str(checkpoint_file))
    explainer = TFTExplainer(
        model,
        background_series=train_ts.targets[company_index],
        background_past_covariates=train_ts.covariates[company_index],
    )
    explainability_result = explainer.explain()

    encoder_imp = explainability_result.get_encoder_importance()
    static_imp = explainability_result.get_static_covariates_importance()
    results["encoder"].append(encoder_imp)
    results["static"].append(static_imp)

# Despite the `_sum` suffix, these hold the mean importance over all folds.
encoder_imp_sum, static_imp_sum = (
    sum(results[kind]) / len(results[kind]) for kind in ("encoder", "static")
)

In [None]:
# encoder_imp = explainability_result.get_encoder_importance()
# static_imp = explainability_result.get_static_covariates_importance()

# encoder_imp = explainability_result.get_encoder_importance()
# static_imp = explainability_result.get_static_covariates_importance()


def _relabel(labels, replacements):
    """Apply the (old, new) substring replacements to an Index of labels, in order."""
    for old, new in replacements:
        labels = labels.str.replace(old, new)
    return labels


# Order the variables by their summed importance, then keep only the more important ones
# (the thresholds are chosen for the readability of the figure).
encoder_order = encoder_imp_sum.sum().sort_values().index
encoder_imp = encoder_imp_sum.reindex(encoder_order, axis=1)
encoder_imp = encoder_imp[encoder_imp.columns[encoder_imp.max() > 3]]

static_order = static_imp_sum.sum().sort_values().index
static_imp = static_imp_sum.reindex(static_order, axis=1)
static_imp = static_imp[static_imp.columns[static_imp.max() > 10]]

# Human-readable variable names
encoder_imp.columns = _relabel(
    encoder_imp.columns,
    (
        ("_pastcov", " (Past Covariate)"),
        ("_target", " (Past Target)"),
        ("add_relative_index_futcov", "Synthetic Relative Index (Future Covariate)"),
        ("Short Term Investments", "Short Term Inv."),
    ),
)
static_imp.columns = _relabel(
    static_imp.columns,
    (
        ("_statcov", ""),
        ("_", " "),
        ("GICS sector ", "GICS sector: "),
    ),
)

uses_static_covariates = not static_imp.empty
fontsize = 10  # Reset the font size
n_rows = 2 if uses_static_covariates else 1

# Plot the encoder weights and, if present, the static covariate weights
fig, axes = plt.subplots(
    nrows=n_rows,
    sharex=True,
    squeeze=False,
    subplot_kw=dict(aspect=0.45),
)
axes = axes.flatten()

# NOTE: `_plot_cov_selection` is a private darts helper and may change between versions.
encoder_ax = axes[0]
explainer._plot_cov_selection(encoder_imp, title="", ax=encoder_ax)
encoder_ax.xaxis.set_label_position("top")
encoder_ax.set_xlabel(encoder_ax.get_xlabel(), fontsize=fontsize)
encoder_ax.set_ylabel("Covariates", fontsize=fontsize)

if uses_static_covariates:
    static_ax = axes[1]
    explainer._plot_cov_selection(static_imp, title="", ax=static_ax)
    static_ax.xaxis.set_label_position("top")
    static_ax.set_xlabel("")
    static_ax.set_ylabel("Static Variables", fontsize=fontsize)

# A negative vertical spacing moves the panels closer together
plt.subplots_adjust(hspace=-0.4)
fig.align_ylabels(axes)
sns.despine(top=True, bottom=True)

path = explainability_path / "tft_explainability.pdf"
plt.savefig(path, bbox_inches="tight")
plt.show()
plt.close()

str(path.absolute())