In [None]:
# need to restart notebook in order to get numpy 1.x.x after installing it ...
!pip install numpy==1.26.4
import numpy as np
# If the already-imported numpy still reports a major version above 1, the pinned
# install above has not taken effect in this process yet: send signal 9 to our own
# process so the kernel restarts and picks up the pinned version on the next run.
if int(np.__version__[0]) > 1:
  import os
  os.kill(os.getpid(), 9)

In [None]:
!pip install timesfm
!pip install sktime
!pip install chronos-forecasting
!pip install tqdm_joblib
!pip install u8darts==0.34.0
!pip install statsforecast
!pip install pytorch_lightning
!pip install pytorch-forecasting

In [None]:
import math
import pandas as pd
import matplotlib.pyplot as plt
from sktime.datasets import load_tsf_to_dataframe, load_forecastingdata, load_fpp3
from sktime.split import ExpandingSlidingWindowSplitter, ExpandingWindowSplitter, TemporalTrainTestSplitter
from sktime.transformations.series.difference import Differencer
from sktime.forecasting.statsforecast import StatsForecastAutoARIMA
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.pytorchforecasting import PytorchForecastingNBeats
from sktime.forecasting.darts import DartsXGBModel
from sklearn.metrics import mean_absolute_percentage_error
import torch
import timesfm
from chronos import BaseChronosPipeline
from joblib import Parallel, delayed
from tqdm.notebook import tqdm
from tqdm_joblib import tqdm_joblib
import warnings
import time
from enum import Enum

In [6]:
class EvalMode(Enum):
  """Evaluation protocol that decides how series are padded and split."""
  # Mapped to TemporalTrainTestSplitter(test_size=horizon_length) by get_ds_splitter;
  # pad_dims_fn pads in front of the series for this mode.
  TRAIN_TEST_SPLIT = 0
  # Mapped to ExpandingWindowSplitter (one horizon per step) by get_ds_splitter;
  # pad_dims_fn pads behind the series for this mode.
  EXPANDING_WINDOW = 1

In [7]:
class ModelType(Enum):
  TimesFM = 0
  TimesFM2 = 1
  Chronos = 2
  AutoARIMA = 3
  XGBoost = 4
  Baseline = 5
  NBeats = 6

In [8]:
def pad_dims_fn(eval_mode, maxlen, x):
  """Return the ``(before, after)`` pad widths that extend ``x`` to ``maxlen`` entries.

  Train-test-split evaluation pads in front of the series; any other
  evaluation mode pads behind it.
  """
  missing = maxlen - len(x)
  if eval_mode == EvalMode.TRAIN_TEST_SPLIT:
    return (missing, 0)
  return (0, missing)

In [9]:
# In-memory cache keyed by dataset name; load_monash_dataset stores the loaded
# (df, metadata) pair here so repeated experiments do not reload the same dataset.
ds_cache = {}

In [10]:
def load_monash_dataset(name):
  """Load a forecasting dataset by name, memoising the result in ``ds_cache``.

  Returns a ``(df, metadata)`` pair. The frame is a copy of the cached one, so
  callers can add or replace columns without touching the cache entry; the
  metadata object is returned as stored.
  """
  if name not in ds_cache:
    # First request for this dataset: load it once and remember the pair.
    ds_cache[name] = load_forecastingdata(name)
  cached_df, cached_metadata = ds_cache[name]
  return cached_df.copy(), cached_metadata

In [11]:
def df_to_ndarray(df, horizon_length, initial_context_length, pad_dims_fn, eval_mode, value_column='series_value', length_column='series_length'):
  """Stack the variable-length series of ``df`` into one zero-padded 2D array.

  Args:
    df: frame with one series per row; ``value_column`` holds the values and
      ``length_column`` their lengths. The frame is NOT modified.
    horizon_length: forecast horizon; only used for expanding-window padding.
    initial_context_length: size of the first context window; only used for
      expanding-window padding.
    pad_dims_fn: callable ``(eval_mode, padded_length, series) -> (before, after)``
      giving the pad widths for one series.
    eval_mode: EvalMode that selects the padded length and is forwarded to
      ``pad_dims_fn``.
    value_column: name of the column holding the series values.
    length_column: name of the column holding the series lengths.

  Returns:
    Array with one row per series, every row padded with zeros to the same length.
  """
  max_series_length = max(df[length_column])
  if eval_mode == EvalMode.TRAIN_TEST_SPLIT:
    # A single split has no initial context: just pad to the longest series.
    padded_length = max_series_length
  else:
    # Round the part after the initial context up to a whole number of horizons
    # so every expanding-window fold has a full horizon to forecast.
    n_horizons = math.ceil((max_series_length - initial_context_length) / horizon_length)
    padded_length = n_horizons * horizon_length + initial_context_length

  # Build the padded rows in a local list instead of overwriting df[value_column],
  # so the caller's frame keeps its original, unpadded series.
  padded_rows = [
      np.pad(series, pad_dims_fn(eval_mode, padded_length, series), constant_values=0)
      for series in df[value_column]
  ]
  return np.vstack(padded_rows)


In [12]:
def init_monash_dataset(name, horizon_length, initial_context_length, pad_dims_fn, eval_mode=EvalMode.TRAIN_TEST_SPLIT, plot=False, max_series_length = None, max_length = None):
  """Load a dataset via ``load_monash_dataset`` and return it as a padded 2D array.

  Args:
    name: dataset name passed to ``load_monash_dataset``.
    horizon_length: forecast horizon, forwarded to ``df_to_ndarray``.
    initial_context_length: first context size, forwarded to ``df_to_ndarray``.
    pad_dims_fn: pad-width callable, forwarded to ``df_to_ndarray``.
    eval_mode: EvalMode, forwarded to ``df_to_ndarray``.
    plot: when True, show a histogram of the series lengths.
    max_series_length: if given, keep only the first ``max_series_length``
      values of every series.
    max_length: if given, keep only the first ``max_length`` rows (series) of
      the dataset.

  Returns:
    The padded array produced by ``df_to_ndarray``.
  """
  df, metadata = load_monash_dataset(name)
  if max_length is not None:
    # head() returns a slice of the frame; copy it so the column assignments
    # below write to an independent frame and do not raise SettingWithCopyWarning.
    df = df.head(max_length).copy()

  if max_series_length is not None:
    df['series_value'] = df['series_value'].apply(lambda x: x[:max_series_length])

  df['series_length'] = df['series_value'].apply(lambda x: len(x))

  if plot:
    plt.hist(df['series_length'])
    plt.xlabel('Series Length')
    plt.ylabel('Frequency')
    plt.show()

  dataset = df_to_ndarray(df, horizon_length, initial_context_length, pad_dims_fn, eval_mode)
  print(name + ", shape: " + str(dataset.shape))
  print("Metadata:", metadata)
  return dataset

In [13]:
def load_etth(horizon_length, initial_context_length, pad_dims_fn, eval_mode=EvalMode.TRAIN_TEST_SPLIT, max_series_length=None):
  """Download ETTh1 and ETTh2 and return their 'OT' columns as a padded 2-row array.

  Each CSV contributes one series (its 'OT' column). When ``max_series_length``
  is given, each series is cut to its first ``max_series_length`` values before
  padding. The remaining arguments are forwarded to ``df_to_ndarray``.
  """
  csv_urls = (
      'https://raw.githubusercontent.com/zhouhaoyi/ETDataset/refs/heads/main/ETT-small/ETTh1.csv',
      'https://raw.githubusercontent.com/zhouhaoyi/ETDataset/refs/heads/main/ETT-small/ETTh2.csv',
  )
  ot_series = [pd.read_csv(url)['OT'].tolist() for url in csv_urls]
  if max_series_length is not None:
    ot_series = [values[:max_series_length] for values in ot_series]

  frame = pd.DataFrame({'series_value': ot_series})
  frame['series_length'] = frame['series_value'].apply(lambda values: len(values))

  dataset = df_to_ndarray(frame, horizon_length, initial_context_length, pad_dims_fn, eval_mode)
  print("ETTh, shape: " + str(dataset.shape))
  return dataset

In [14]:
def get_timesfm_model(horizon_length, v1=True):
  """Build a TimesFM model that forecasts ``horizon_length`` steps.

  ``v1=True`` loads the "google/timesfm-1.0-200m-pytorch" checkpoint with the
  shared hyper-parameters only; ``v1=False`` loads
  "google/timesfm-2.0-500m-pytorch" and additionally sets the layer count,
  positional-embedding flag and context length passed below.
  """
  # Hyper-parameters common to both checkpoints.
  hparams_kwargs = {
      "backend": "gpu",
      "per_core_batch_size": 32,
      "horizon_len": horizon_length,
  }
  if v1:
    repo_id = "google/timesfm-1.0-200m-pytorch"
  else:
    repo_id = "google/timesfm-2.0-500m-pytorch"
    hparams_kwargs.update(
        num_layers=50,
        use_positional_embedding=False,
        context_len=2048,
    )

  hparams = timesfm.TimesFmHparams(**hparams_kwargs)
  checkpoint = timesfm.TimesFmCheckpoint(huggingface_repo_id=repo_id)
  return timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)


In [15]:
def get_chronos_model(name='amazon/chronos-bolt-base'):
  """Load the pretrained Chronos pipeline ``name`` on "cuda" with bfloat16 weights."""
  load_options = {
      "device_map": "cuda",
      "torch_dtype": torch.bfloat16,
  }
  return BaseChronosPipeline.from_pretrained(name, **load_options)

In [16]:
def get_nbeats_model(horizon_length):
  """Create a PytorchForecastingNBeats forecaster for ``horizon_length`` steps.

  The trainer runs for a single epoch; the model uses one generic stack
  (recommended params for generic mode).
  """
  trainer_params = {"max_epochs": 1}
  model_params = {
      "prediction_length": horizon_length,
      "stack_types": ["generic"],
      "num_blocks": [1],
      "num_block_layers": [4],
      "widths": [512],
      "sharing": False,
      "expansion_coefficient_lengths": [32],
  }
  return PytorchForecastingNBeats(trainer_params=trainer_params, model_params=model_params)

In [17]:
def get_timesfm_predict_fn(horizon_length, frequency, v1=True):
  """Return ``predict(context)`` backed by a TimesFM model built once up front.

  The returned function passes ``frequency`` for every row of ``context`` and
  returns only the point forecast; the quantile forecast is discarded.
  """
  forecaster = get_timesfm_model(horizon_length, v1)

  def _predict(context):
    freq_per_series = np.repeat(frequency, len(context))
    point_forecast, _quantile_forecast = forecaster.forecast(context, freq=freq_per_series)
    return point_forecast

  return _predict

In [18]:
def get_chronos_predict_fn(horizon_length, model_name='amazon/chronos-bolt-base'):
  """Return ``predict(context)`` backed by a Chronos pipeline loaded once up front.

  The returned function requests the 0.5 quantile for ``horizon_length`` steps
  and returns it as a numpy array; the mean forecast is discarded.
  """
  pipeline = get_chronos_model(model_name)

  def _predict(context):
    quantiles, _mean = pipeline.predict_quantiles(
        context=torch.from_numpy(context),
        prediction_length=horizon_length,
        quantile_levels=[0.5],
    )
    # Only one quantile level was requested, so it sits at index 0 of the last axis.
    median = quantiles[:, :, 0]
    return median.numpy()

  return _predict

In [19]:
def get_xgb_predict_fn(horizon_length, max_lags):
  """Return ``predict(context)`` that fits a fresh DartsXGBModel on every call.

  The number of lags is ``max_lags``, capped at one less than the context width.
  """
  def _predict(context):
    usable_lags = min(max_lags, context.shape[1] - 1)
    forecaster = DartsXGBModel(lags=usable_lags)
    panel = np.expand_dims(context, axis=1)
    steps_ahead = np.arange(1, horizon_length + 1)
    forecast = forecaster.fit_predict(panel, fh=steps_ahead)
    return forecast[:, 0, :]

  return _predict

In [20]:
def get_sktime_forecaster_parallel_predict_fn(model, horizon_length):
  """Return ``predict(context)`` that fits ``model`` on each row via joblib.

  Every row of ``context`` is submitted as its own job (``n_jobs=-1``) with a
  tqdm progress bar; the per-row forecasts are stacked and axis 2 is squeezed.
  """
  def _forecast_one(series):
    model.reset()
    return model.fit_predict(series, fh=np.arange(1, horizon_length + 1))

  def _predict(context):
    n_series = context.shape[0]
    jobs = (delayed(_forecast_one)(context[row, :]) for row in range(n_series))
    with tqdm_joblib(tqdm(desc="Processing", total=n_series)):
      per_series_forecasts = Parallel(n_jobs=-1, verbose=10)(jobs)
    return np.array(per_series_forecasts).squeeze(axis=2)

  return _predict

In [21]:
def get_sktime_forecaster_predict_fn(model, horizon_length):
  """Return ``predict(context)`` that resets ``model`` and fits it on the whole batch.

  ``context`` gets an extra axis at position 1 before being handed to
  ``model.fit_predict``; index 0 of that axis is taken from the forecast.
  """
  def _predict(context):
    model.reset()
    panel = np.expand_dims(context, axis=1)
    steps_ahead = np.arange(1, horizon_length + 1)
    return model.fit_predict(panel, fh=steps_ahead)[:, 0, :]

  return _predict

In [22]:
def get_post_padding_mask(values, padding_value=0):
    """Mark, per row, everything up to the last entry that differs from ``padding_value``.

    Args:
        values: 2D array, one series per row, possibly padded at the end.
        padding_value: sentinel used for padding (default 0).

    Returns:
        Boolean array of the same shape; True from column 0 through the last
        non-padding column of each row, False for the trailing padding.

    Notes:
        * With pre-padding the forecast horizon carries no padding, so the mask
          is all True there and nothing is excluded from the error metrics.
        * Genuine trailing values equal to ``padding_value`` cannot be told apart
          from padding and are masked out; a row consisting only of
          ``padding_value`` keeps just column 0.
    """
    col_indices = np.arange(values.shape[1])
    # Honour the caller's sentinel instead of a hard-coded 0; `initial=0` keeps
    # the reduction defined for inputs without columns and cannot change the
    # result otherwise because column indices are never negative.
    last_valid_idx = np.max((values != padding_value) * col_indices, axis=1, initial=0)
    return col_indices[None, :] <= last_valid_idx[:, None]

In [23]:
def rmse_with_padding(predictions, truth, padding_value=0):
  """Root mean squared error over the entries of ``truth`` that are not trailing padding.

  Entries excluded by ``get_post_padding_mask`` get weight 0, all others weight 1.
  """
  valid = get_post_padding_mask(truth, padding_value).astype(float)
  squared_errors = (predictions - truth) ** 2
  mean_squared_error = np.sum(valid * squared_errors) / np.sum(valid)
  return np.sqrt(mean_squared_error)

In [24]:
def mae_with_padding(predictions, truth, padding_value=0):
  """Mean absolute error over the entries of ``truth`` that are not trailing padding."""
  valid = get_post_padding_mask(truth, padding_value)
  return np.mean(np.abs(truth - predictions)[valid])

In [25]:
def weighted_mape_with_padding(predictions, truth, padding_value=0):
  """Weighted MAPE in percent: ``100 * sum(|truth - pred|) / sum(|truth|)``.

  Only entries kept by ``get_post_padding_mask`` contribute to either sum.

  Returns:
    The percentage, or ``np.nan`` when the sum of absolute truth values is 0.
  """
  mask = get_post_padding_mask(truth, padding_value)
  numerator = np.sum(np.abs(truth - predictions)[mask])
  # Normalise by the absolute actuals: summing signed values lets negative
  # observations cancel positive ones, which shrinks or flips the denominator.
  denominator = np.sum(np.abs(truth)[mask])

  if denominator == 0:
    return np.nan
  return (numerator / denominator) * 100

In [26]:
def calc_metrics(actual, predictions):
  """Return the padding-aware 'mae', 'wmape' and 'rmse' of ``predictions`` against ``actual``."""
  return {
      'mae': mae_with_padding(predictions, actual),
      'wmape': weighted_mape_with_padding(predictions, actual),
      'rmse': rmse_with_padding(predictions, actual),
  }

In [27]:
def predict(dataset, predict_fn, splits_indices, horizon_length):
  """Forecast every split and collect the forecasts plus per-split metrics.

  Args:
    dataset: 2D array, one series per row.
    predict_fn: callable mapping a context array to a forecast array with one
      row per series.
    splits_indices: iterable of ``(context_idx, horizon_idx)`` column indexers.
    horizon_length: not used by this function; kept so existing callers keep working.

  Returns:
    ``(predictions, partial_metrics)``: the forecasts of all splits joined
    along axis 1 (zero columns when there are no splits), and one
    ``calc_metrics`` dict per split.
  """
  # Start from the same empty (n_series, 0) float array the result is built on,
  # then join everything once at the end: np.append inside the loop would copy
  # the whole accumulated matrix again for every split.
  forecasts = [np.empty(shape=(len(dataset), 0))]
  partial_metrics = []
  for context_idx, horizon_idx in splits_indices:
    context = dataset[:, context_idx]
    horizon = dataset[:, horizon_idx]
    forecast = predict_fn(context)

    partial_metrics.append(calc_metrics(horizon, forecast))
    forecasts.append(forecast)
  return np.concatenate(forecasts, axis=1), partial_metrics

In [28]:
def infer(dataset, splitter, predict_fn, horizon_length, context_length):
  """Run ``predict_fn`` over all splits of ``dataset`` and summarise the outcome.

  The splits are derived from the first series and applied to every row.

  Returns a dict with:
    'predictions': forecasts of all splits joined along axis 1,
    'metrics': metrics of those forecasts against ``dataset[:, context_length:]``,
    'partial_metrics': one metrics dict per split,
    'time': seconds spent in ``predict``, rounded to one decimal, as a string.
  """
  splits = list(splitter.split(dataset[0]))

  started_at = time.time()
  predictions, partial_metrics = predict(dataset, predict_fn, splits, horizon_length)
  elapsed_time = str(round(time.time() - started_at, 1))

  overall_metrics = calc_metrics(dataset[:, context_length:], predictions)
  return {
      'predictions': predictions,
      'metrics': overall_metrics,
      'partial_metrics': partial_metrics,
      'time': elapsed_time,
  }

In [29]:
def get_ds_splitter(horizon_length, eval_mode=EvalMode.TRAIN_TEST_SPLIT, initial_context_length=0):
  """Return the sktime splitter that implements ``eval_mode``.

  Train-test split: a TemporalTrainTestSplitter with ``test_size=horizon_length``.
  Otherwise: an ExpandingWindowSplitter that starts with ``initial_context_length``
  points and advances by ``horizon_length`` per step.
  """
  if eval_mode == EvalMode.TRAIN_TEST_SPLIT:
    return TemporalTrainTestSplitter(test_size=horizon_length)
  return ExpandingWindowSplitter(
      fh=np.arange(1, horizon_length + 1),
      step_length=horizon_length,
      initial_window=initial_context_length,
  )

In [30]:
def infer_on_dataset(dataset, eval_mode, horizon_length, models, context_length=0, timesfm_freq=0, seasonal_period=1, xgb_lags=1):
  """Evaluate every requested model on ``dataset`` and return the results by model name.

  All predict functions are built first (which loads the underlying models),
  then each one is run through ``infer`` with warnings suppressed.

  Args:
    dataset: 2D array, one series per row.
    eval_mode: EvalMode used to pick the splitter.
    horizon_length: number of steps to forecast per split.
    models: iterable of ModelType members to evaluate, in order.
    context_length: initial window for the splitter and start of the evaluated
      region.
    timesfm_freq: frequency value handed to the TimesFM models.
    seasonal_period: ``sp`` for AutoARIMA and the naive baseline.
    xgb_lags: upper bound on the XGBoost lag count.

  Returns:
    Dict mapping ``ModelType.name`` to the dict returned by ``infer``.
  """
  splitter = get_ds_splitter(horizon_length, eval_mode, context_length)

  def _timesfm_v1():
    return get_timesfm_predict_fn(horizon_length, frequency=timesfm_freq, v1=True)

  def _timesfm_v2():
    return get_timesfm_predict_fn(horizon_length, frequency=timesfm_freq, v1=False)

  def _chronos():
    return get_chronos_predict_fn(horizon_length)

  def _auto_arima():
    # Using the parallel approach for every TS gives faster inference time for AutoARIMA.
    return get_sktime_forecaster_parallel_predict_fn(StatsForecastAutoARIMA(sp=seasonal_period), horizon_length)

  def _xgboost():
    return get_xgb_predict_fn(horizon_length, max_lags=xgb_lags)

  def _nbeats():
    return get_sktime_forecaster_predict_fn(get_nbeats_model(horizon_length), horizon_length)

  def _baseline():
    return get_sktime_forecaster_predict_fn(NaiveForecaster(sp=seasonal_period), horizon_length)

  predict_fn_factories = {
      ModelType.TimesFM: _timesfm_v1,
      ModelType.TimesFM2: _timesfm_v2,
      ModelType.Chronos: _chronos,
      ModelType.AutoARIMA: _auto_arima,
      ModelType.XGBoost: _xgboost,
      ModelType.NBeats: _nbeats,
      ModelType.Baseline: _baseline,
  }

  # Instantiate every requested predictor before any inference starts.
  named_predict_fns = [(model.name, predict_fn_factories[model]()) for model in models]

  results = {}
  for model_name, predict_fn in named_predict_fns:
    print("Infer using " + model_name + "...")
    with warnings.catch_warnings():
      warnings.filterwarnings("ignore")
      results[model_name] = infer(dataset, splitter, predict_fn, horizon_length, context_length)

  return results

In [31]:
def plot_series(dataset, predictions, context_length, horizon_length, series_indices, cols=3, row_size=3, plotted_context_length=0):
  """Plot actual values and model forecasts for the selected series on a grid.

  Args:
    dataset: 2D array, one (padded) series per row.
    predictions: dict mapping model name to a 2D forecast array aligned with
      ``dataset[:, context_length:]``. The 'Baseline' entry is not drawn.
    context_length: column at which the forecast region starts.
    horizon_length: spacing of the dashed vertical separators.
    series_indices: row indices of the series to plot, one subplot each.
    cols: number of subplot columns.
    row_size: figure height per subplot row.
    plotted_context_length: how many context points to show before the
      forecast region (0 shows none).
  """
  if len(series_indices) == 0:
    # Nothing to draw; plt.subplots cannot create a grid with zero rows.
    return

  rows = math.ceil(len(series_indices) / cols)
  # squeeze=False always yields a 2D axes array, whatever rows/cols are.
  fig, axes = plt.subplots(rows, cols, figsize=(15, row_size * rows), squeeze=False)

  for idx, ax in enumerate(axes.flat):
    if idx >= len(series_indices):
      # Unused cell in the last grid row.
      ax.axis('off')
      continue

    series_idx = series_indices[idx]
    truth = dataset[series_idx][context_length:]

    # Only Trues for pre-padding (train-test split), so no side effects there;
    # with post-padding (expanding window) this drops the trailing padding.
    padding_mask = get_post_padding_mask(truth.reshape(1, -1))
    non_padded_values_idx = np.nonzero(padding_mask[0])[0]
    plotted_truth = truth[non_padded_values_idx]

    # Context length actually drawn for THIS series. Kept in a local so that
    # trimming one series does not shorten the context of the following ones.
    shown_context_length = 0
    if plotted_context_length > 0:
      # Clamp at 0: a negative start would slice from the end of the series.
      ctx_start = max(context_length - plotted_context_length, 0)
      initial_ctx = dataset[series_idx][ctx_start:context_length]
      if initial_ctx.size:
        # When using pre-padding (train-test eval) trim the leading zeros, if any.
        initial_ctx = initial_ctx[np.argmax(initial_ctx != 0):]
      shown_context_length = len(initial_ctx)
      plotted_truth = np.concatenate((initial_ctx, plotted_truth))

    ax.plot(plotted_truth, label="Actual")

    for name, model_predictions in predictions.items():
      if name == 'Baseline':
        continue

      prediction = model_predictions[series_idx]
      # Left-pad with NaN so forecasts line up behind the plotted context.
      plotted_prediction = np.pad(prediction[non_padded_values_idx], (shown_context_length, 0), constant_values=np.nan)
      ax.plot(plotted_prediction, label=name, alpha=0.7)

    # Vertical lines for horizons. The x axis counts plotted points, so the
    # first horizon starts right after the plotted context, not at the
    # absolute context_length of the dataset.
    ymin, ymax = ax.get_ylim()
    horizon_starts = np.arange(start=shown_context_length, step=horizon_length, stop=len(plotted_truth))
    ax.vlines(x=horizon_starts, ymin=ymin, ymax=ymax, colors='lightblue', ls='--', alpha=0.7)

    ax.set_title("Series " + str(series_idx + 1))
    ax.legend()
  plt.tight_layout()
  plt.show()

In [32]:
def run(experiments):
  """Execute every experiment config and return the outcomes keyed by experiment name.

  For each entry the dataset is built via its 'dataset_init_fn', the context
  length is derived (everything but the last horizon for a train-test split,
  otherwise the configured initial context length) and all configured models
  are evaluated. 'xgb_lags' falls back to 'seasonal_period' when absent.

  Returns:
    Dict mapping experiment name to
    ``{'dataset': ..., 'context_length': ..., 'model_results': ...}``.
  """
  results = {}
  for name, experiment in experiments.items():
    print("Running experiment '{}'..".format(name))
    eval_mode = experiment['eval_mode']
    horizon_length = experiment['horizon_length']
    initial_context_length = experiment['initial_context_length']

    dataset = experiment['dataset_init_fn'](eval_mode, horizon_length, initial_context_length)

    if eval_mode == EvalMode.TRAIN_TEST_SPLIT:
      context_length = len(dataset[0]) - horizon_length
    else:
      context_length = initial_context_length

    seasonal_period = experiment['seasonal_period']
    model_results = infer_on_dataset(
        dataset,
        eval_mode,
        horizon_length,
        experiment['models'],
        context_length,
        experiment['timesfm_freq'],
        seasonal_period,
        experiment.get('xgb_lags', seasonal_period),
    )

    results[name] = {
        'dataset': dataset,
        'context_length': context_length,
        'model_results': model_results,
    }
  return results


In [33]:
def plot_partial_metrics(model_results):
  """Show, for every model, one figure with a subplot per metric across the folds."""
  for model_name in model_results:
    fold_metrics = pd.DataFrame(model_results[model_name]['partial_metrics'])
    fold_metrics.plot(
        subplots=True,
        layout=(3, 1),
        figsize=(8, 6),
        marker='o',
        title=f'Error terms per fold for {model_name}',
        sharex=False,
    )
    plt.tight_layout()
    plt.show()

In [34]:
def plot_experiment_result(exp_name, exp_result, exp_config):
  """Print the metrics of every model of one experiment and draw the configured plots.

  Args:
    exp_name: experiment name, used in the header line.
    exp_result: entry produced by ``run`` ('dataset', 'context_length',
      'model_results').
    exp_config: the experiment's config dict; its optional 'plot_config'
      controls the series plot ('plot', default True) and the per-fold metric
      plot ('plot_partial_metrics', default False).

  The printed "MASE" is the model's MAE divided by the MAE of the 'Baseline'
  model of the same experiment. If the experiment has no 'Baseline' result the
  ratio cannot be computed and "n/a" is printed instead of failing.
  """
  print("Experiment '{}'".format(exp_name))
  predictions = {}

  model_results = exp_result['model_results']
  # The baseline is optional: experiments may be configured without it.
  baseline = model_results.get('Baseline')

  for model_name, model_result in model_results.items():

    predictions[model_name] = model_result['predictions']
    print(model_name + ":")
    print("MAE: ", model_result['metrics']['mae'])
    print("RMSE: ", model_result['metrics']['rmse'])
    if baseline is not None:
      print("MASE: ", model_result['metrics']['mae'] / baseline['metrics']['mae'])
    else:
      print("MASE: ", "n/a (no 'Baseline' result to scale by)")
    print("wMAPE: ", model_result['metrics']['wmape'])
    print("Time: ", model_result['time'])
    print("\n")

  plot_config = exp_config.get('plot_config')
  if plot_config is not None and plot_config.get('plot', True):
    plot_series(
        dataset=exp_result['dataset'],
        predictions=predictions,
        context_length=exp_result['context_length'],
        horizon_length=exp_config['horizon_length'],
        series_indices=plot_config['series_indices'],
        cols=plot_config.get('cols', 3),
        plotted_context_length=plot_config.get('plotted_context_length', exp_result['context_length']))

  if plot_config is not None and plot_config.get('plot_partial_metrics', False):
    plot_partial_metrics(exp_result['model_results'])


In [35]:
def _all_models():
  """Return a fresh list of every model type, in the order all experiments evaluate them."""
  return [
      ModelType.TimesFM,
      ModelType.TimesFM2,
      ModelType.Chronos,
      ModelType.AutoARIMA,
      ModelType.XGBoost,
      ModelType.NBeats,
      ModelType.Baseline,
  ]


def _monash_init_fn(dataset_name, **loader_kwargs):
  """Build the 'dataset_init_fn' callback for a dataset loaded through init_monash_dataset."""
  return lambda em, hl, icl: init_monash_dataset(dataset_name, horizon_length=hl, initial_context_length=icl, pad_dims_fn=pad_dims_fn, eval_mode=em, **loader_kwargs)


def _etth_init_fn():
  """Build the 'dataset_init_fn' callback for the ETTh data (series cut to 4096 values)."""
  return lambda em, hl, icl: load_etth(horizon_length=hl, initial_context_length=icl, pad_dims_fn=pad_dims_fn, eval_mode=em, max_series_length=4096)


def _experiment(eval_mode, horizon_length, initial_context_length, timesfm_freq, seasonal_period, xgb_lags, dataset_init_fn, plot, series_indices, cols, plotted_context_length=None):
  """Assemble one experiment config dict in the shape consumed by run / plot_experiment_result.

  'plotted_context_length' is only added to the plot config when given, so that
  plot_experiment_result falls back to its own default otherwise.
  """
  plot_config = {
      'plot': plot,
      'plot_partial_metrics': False,
      'series_indices': series_indices,
      'cols': cols,
  }
  if plotted_context_length is not None:
    plot_config['plotted_context_length'] = plotted_context_length
  return {
      'eval_mode': eval_mode,
      'models': _all_models(),
      'horizon_length': horizon_length,
      # Only used by expanding-window runs; run() ignores it for a train-test split.
      'initial_context_length': initial_context_length,
      'timesfm_freq': timesfm_freq,
      'seasonal_period': seasonal_period,
      'xgb_lags': xgb_lags,
      'dataset_init_fn': dataset_init_fn,
      'plot_config': plot_config,
  }


def _us_births(eval_mode, horizon_length, initial_context_length, plot):
  """Config for the us_births dataset."""
  # NOTE(review): max_length caps the number of series (rows) kept by
  # init_monash_dataset, not the length of each series - confirm this is intended.
  return _experiment(
      eval_mode, horizon_length, initial_context_length,
      timesfm_freq=0, seasonal_period=7, xgb_lags=2048,
      dataset_init_fn=_monash_init_fn("us_births_dataset", max_length=4096),
      plot=plot, series_indices=[0], cols=1, plotted_context_length=256)


def _fred_md(eval_mode, horizon_length, initial_context_length, plot, plotted_context_length=None):
  """Config for the fred_md dataset."""
  return _experiment(
      eval_mode, horizon_length, initial_context_length,
      timesfm_freq=1, seasonal_period=1, xgb_lags=2048,
      dataset_init_fn=_monash_init_fn("fred_md_dataset"),
      plot=plot, series_indices=np.arange(107), cols=2,
      plotted_context_length=plotted_context_length)


def _etth(eval_mode, horizon_length, initial_context_length, xgb_lags, plot, plotted_context_length):
  """Config for the ETTh data loaded by load_etth."""
  return _experiment(
      eval_mode, horizon_length, initial_context_length,
      timesfm_freq=0, seasonal_period=24, xgb_lags=xgb_lags,
      dataset_init_fn=_etth_init_fn(),
      plot=plot, series_indices=np.arange(2), cols=1,
      plotted_context_length=plotted_context_length)


def _covid_deaths(eval_mode, horizon_length, initial_context_length, plot, n_plotted_series):
  """Config for the covid_deaths dataset."""
  return _experiment(
      eval_mode, horizon_length, initial_context_length,
      timesfm_freq=0, seasonal_period=1, xgb_lags=2048,
      dataset_init_fn=_monash_init_fn("covid_deaths_dataset"),
      plot=plot, series_indices=np.arange(n_plotted_series), cols=2)


# One entry per (dataset, evaluation mode, horizon). Positional arguments are
# eval_mode, horizon_length, initial_context_length.
experiments = {
    'us_births_dataset_tts_30': _us_births(EvalMode.TRAIN_TEST_SPLIT, 30, 0, plot=True),
    'us_births_dataset_ew_30': _us_births(EvalMode.EXPANDING_WINDOW, 30, 2048, plot=False),
    'us_births_dataset_tts_365': _us_births(EvalMode.TRAIN_TEST_SPLIT, 365, 0, plot=True),
    'us_births_dataset_ew_365': _us_births(EvalMode.EXPANDING_WINDOW, 365, 2048, plot=False),
    'fred_md_dataset_tts_12': _fred_md(EvalMode.TRAIN_TEST_SPLIT, 12, 0, plot=False),
    'fred_md_dataset_ew_12': _fred_md(EvalMode.EXPANDING_WINDOW, 12, 256, plot=False),
    'fred_md_dataset_tts_60': _fred_md(EvalMode.TRAIN_TEST_SPLIT, 60, 0, plot=True, plotted_context_length=120),
    'fred_md_dataset_ew_60': _fred_md(EvalMode.EXPANDING_WINDOW, 60, 256, plot=True),
    'etth1_tts_96': _etth(EvalMode.TRAIN_TEST_SPLIT, 96, 0, xgb_lags=24, plot=True, plotted_context_length=192),
    'etth1_ew_96': _etth(EvalMode.EXPANDING_WINDOW, 96, 2048, xgb_lags=2048, plot=False, plotted_context_length=120),
    'etth1_tts_192': _etth(EvalMode.TRAIN_TEST_SPLIT, 192, 0, xgb_lags=24, plot=True, plotted_context_length=480),
    'etth1_ew_192': _etth(EvalMode.EXPANDING_WINDOW, 192, 2048, xgb_lags=2048, plot=False, plotted_context_length=120),
    'covid_deaths_dataset_tts_7': _covid_deaths(EvalMode.TRAIN_TEST_SPLIT, 7, 0, plot=True, n_plotted_series=20),
    'covid_deaths_dataset_ew_7': _covid_deaths(EvalMode.EXPANDING_WINDOW, 7, 96, plot=False, n_plotted_series=200),
    'covid_deaths_dataset_tts_30': _covid_deaths(EvalMode.TRAIN_TEST_SPLIT, 30, 0, plot=True, n_plotted_series=20),
    'covid_deaths_dataset_ew_30': _covid_deaths(EvalMode.EXPANDING_WINDOW, 30, 96, plot=False, n_plotted_series=212),
}

In [None]:
# Raise the "darts" logger to ERROR so its lower-level messages do not flood the cell output.
import logging
logging.getLogger("darts").setLevel(logging.ERROR)

# Execute every configured experiment, then report each result using the
# config entry with the same name (it carries the plot settings).
r = run(experiments)
for exp_name, exp_result in r.items():
  plot_experiment_result(exp_name, exp_result, experiments[exp_name])