# Quick Start: Running Chronos and Chronos-Bolt models on gift-eval benchmark

This notebook shows how to run Chronos and Chronos-Bolt models on the gift-eval benchmark.

Make sure you download the gift-eval benchmark and set the `GIFT_EVAL` environment variable correctly before running this notebook.

We will use the `Dataset` class to load the data and run the model. If you have not already done so, please check out the [dataset.ipynb](./dataset.ipynb) notebook to learn more about the `Dataset` class. The evaluation below iterates over every dataset listed in `resources/train_test/info.csv`; to run on only a few datasets for brevity, filter the `df` DataFrame before the evaluation loop.

Install Chronos package:
```bash
pip install chronos-forecasting
```

In [3]:
import json
import pandas as pd

from dotenv import load_dotenv
from pathlib import Path

# Read variables from a local `.env` file (python-dotenv) before anything
# else runs.
load_dotenv()

# Name of the benchmark split; it selects the info file here and is reused
# for the results directory further down.
split_name = "train_test"
info_path = Path("resources") / split_name / "info.csv"

# One row per dataset configuration; the evaluation loop iterates over it.
df = pd.read_csv(info_path)

prop_path = Path("notebooks") / "dataset_properties.json"
# Open through a context manager so the handle is closed right after parsing
# instead of being left to the garbage collector. JSON is UTF-8 by
# specification, so do not depend on the platform's default encoding.
with open(prop_path, encoding="utf-8") as prop_file:
    dataset_properties_map = json.load(prop_file)

In [4]:
from gluonts.ev.metrics import (
    MAE,
    MAPE,
    MASE,
    MSE,
    MSIS,
    ND,
    NRMSE,
    RMSE,
    SMAPE,
    MeanWeightedSumQuantileLoss,
)

# Metrics passed to the evaluator. The order is kept as:
# MSE (mean), MSE (0.5), MAE, MASE, MAPE, SMAPE, MSIS, RMSE, NRMSE, ND,
# and finally the mean weighted sum quantile loss.

# Mean squared error is requested twice: once against the mean forecast and
# once against the 0.5-quantile forecast.
metrics = [MSE(forecast_type=kind) for kind in ("mean", 0.5)]

# Metrics used with their default constructor arguments.
metrics += [
    metric_cls()
    for metric_cls in (MAE, MASE, MAPE, SMAPE, MSIS, RMSE, NRMSE, ND)
]

# Quantile loss over the nine deciles 0.1, 0.2, ..., 0.9. `decile / 10` is a
# correctly rounded division, so each value equals the corresponding literal.
metrics.append(
    MeanWeightedSumQuantileLoss(
        quantile_levels=[decile / 10 for decile in range(1, 10)]
    )
)

## Chronos Predictor

For foundation models, we need to implement a wrapper containing the model and use the wrapper to generate predictions.

This is just meant to be a simple wrapper to get you started, feel free to use your own custom implementation to wrap any model.

In [None]:
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import torch
from chronos import BaseChronosPipeline, ForecastType
from gluonts.itertools import batcher
from gluonts.model import Forecast
from gluonts.model.forecast import QuantileForecast, SampleForecast
from tqdm.auto import tqdm


@dataclass
class ModelConfig:
    """Forecast key names derived from an optional list of quantile levels.

    Only ``quantile_levels`` is accepted by the constructor. The remaining
    fields are computed in ``__post_init__``:

    * ``forecast_keys``: ``"mean"`` followed by ``str(level)`` for each level.
    * ``statsforecast_keys``: ``"mean"`` followed by ``"lo-<width>"`` or
      ``"hi-<width>"`` for each level (``"hi"`` only for levels above 0.5).
    * ``intervals``: the distinct widths in ascending order, or ``None`` when
      no quantile levels were given.
    """

    quantile_levels: Optional[List[float]] = None
    forecast_keys: List[str] = field(init=False)
    statsforecast_keys: List[str] = field(init=False)
    intervals: Optional[List[int]] = field(init=False)

    def __post_init__(self):
        """Populate the derived key lists and interval widths."""
        self.forecast_keys = ["mean"]
        self.statsforecast_keys = ["mean"]

        levels = self.quantile_levels
        if levels is not None:
            widths = set()
            for level in levels:
                # Width, in percent, of the central interval whose edge is
                # this quantile: 0.1 and 0.9 both map to 80, 0.5 maps to 0.
                width = round(200 * (max(level, 1 - level) - 0.5))
                widths.add(width)
                bound = "hi" if level > 0.5 else "lo"
                self.forecast_keys.append(str(level))
                self.statsforecast_keys.append(f"{bound}-{width}")
            self.intervals = sorted(widths)
        else:
            self.intervals = None


class ChronosPredictor:
    """Thin wrapper exposing a Chronos pipeline through a ``predict`` method.

    The pipeline is loaded with ``BaseChronosPipeline.from_pretrained`` and
    its raw outputs are converted into gluonts ``Forecast`` objects.
    """

    def __init__(
        self,
        model_path,
        num_samples: int,
        prediction_length: int,
        *args,
        **kwargs,
    ):
        """Load the pipeline and store the forecasting settings.

        Args:
            model_path: First argument given to
                ``BaseChronosPipeline.from_pretrained``.
            num_samples: Passed to ``pipeline.predict`` as ``num_samples``
                when the pipeline's forecast type is ``ForecastType.SAMPLES``;
                unused otherwise.
            prediction_length: Passed to ``pipeline.predict`` as
                ``prediction_length``. It is read at prediction time, so it
                may be reassigned between calls to ``predict``.
            *args: Forwarded unchanged to ``from_pretrained``.
            **kwargs: Forwarded unchanged to ``from_pretrained``.
        """
        self.pipeline = BaseChronosPipeline.from_pretrained(
            model_path,
            *args,
            **kwargs,
        )
        self.prediction_length = prediction_length
        self.num_samples = num_samples

    def predict(self, test_data_input, batch_size: int = 1024) -> List[Forecast]:
        """Forecast every entry of ``test_data_input``.

        Args:
            test_data_input: Iterable of entries providing ``"target"`` and
                ``"start"``. It is iterated twice (once to predict, once to
                build the forecasts), so it must be re-iterable.
            batch_size: Number of series sent to the pipeline per call. It is
                halved after each CUDA out-of-memory error and the whole pass
                is restarted.

        Returns:
            One ``SampleForecast`` or ``QuantileForecast`` per input entry,
            depending on the pipeline's forecast type; an empty list when the
            input yields no entries.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            torch.cuda.OutOfMemoryError: If prediction still runs out of
                memory with a batch size of 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        pipeline = self.pipeline
        # Only sample-based pipelines receive `num_samples`.
        predict_kwargs = (
            {"num_samples": self.num_samples}
            if pipeline.forecast_type == ForecastType.SAMPLES
            else {}
        )
        while True:
            try:
                # Start from scratch on every attempt so that outputs from an
                # aborted attempt are never mixed with the retry.
                batch_outputs = []
                for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
                    context = [torch.tensor(entry["target"]) for entry in batch]
                    batch_outputs.append(
                        pipeline.predict(
                            context,
                            prediction_length=self.prediction_length,
                            **predict_kwargs,
                        ).numpy()
                    )
                break
            except torch.cuda.OutOfMemoryError:
                if batch_size == 1:
                    # Halving again would give a batch size of 0, which
                    # produces no batches at all and fails later with an
                    # unrelated error. Surface the real problem instead.
                    raise
                print(
                    f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size // 2}"
                )
                batch_size //= 2

        # `np.concatenate` rejects an empty list, so handle empty input here.
        if not batch_outputs:
            return []
        forecast_outputs = np.concatenate(batch_outputs)

        # Convert the raw outputs into gluonts Forecast objects
        forecasts = []
        for item, ts in zip(forecast_outputs, test_data_input):
            # NOTE(review): assumes `ts["start"]` supports integer addition
            # (e.g. a pandas Period) and that `len(ts["target"])` is the
            # number of observed time steps -- confirm for the dataset used.
            forecast_start_date = ts["start"] + len(ts["target"])

            if pipeline.forecast_type == ForecastType.SAMPLES:
                forecasts.append(
                    SampleForecast(samples=item, start_date=forecast_start_date)
                )
            elif pipeline.forecast_type == ForecastType.QUANTILES:
                forecasts.append(
                    QuantileForecast(
                        forecast_arrays=item,
                        forecast_keys=list(map(str, pipeline.quantiles)),
                        start_date=forecast_start_date,
                    )
                )

        return forecasts

  from .autonotebook import tqdm as notebook_tqdm


## Evaluation

Now that we have our predictor class, we can use it to predict on the gift-eval benchmark datasets. We will use the `evaluate_model` function to evaluate the model. This function is a helper function to evaluate the model on the test data and return the results in a dictionary. We are going to follow the naming conventions explained in the [README](../README.md) file to store the results in a csv file called `all_results.csv` under the `results/<model_name>/<split_name>` folder (here `results/chronos_bolt_base/train_test`).

The first column in the csv file is the dataset config name which is a combination of the dataset name, frequency and the term:

```python
f"{dataset_name}/{freq}/{term}"
```


In [6]:
import logging


class WarningFilter(logging.Filter):
    """Logging filter that drops records whose message contains a given text."""

    def __init__(self, text_to_filter):
        """Store the substring that marks a record for suppression."""
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record):
        """Return ``False`` to drop the record when its message contains the text."""
        message = record.getMessage()
        should_drop = self.text_to_filter in message
        return not should_drop


# Drop records on the "gluonts.model.forecast" logger whose message contains
# the text below; all other records on that logger still pass through.
mean_not_stored_filter = WarningFilter(
    "The mean prediction is not stored in the forecast data"
)
gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(mean_not_stored_filter)

In [7]:
import csv
from pathlib import Path
from gluonts.model import evaluate_model
from gift_eval.data import Dataset

model_name = "chronos_bolt_base"

# Results go to ../results/<model_name>/<split_name>/all_results.csv.
output_dir = Path("..") / "results" / model_name / split_name
output_dir.mkdir(parents=True, exist_ok=True)
output_file = "all_results.csv"

csv_path = output_dir / output_file

pretty_names = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Keys under which the evaluation result exposes each metric. The CSV header
# and every data row are derived from this one list so they cannot drift
# apart.
metric_keys = [
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
]

# Start a fresh results file containing only the header row.
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(
        ["dataset", "model"]
        + [f"eval_metrics/{key}" for key in metric_keys]
        + ["domain", "num_variates"]
    )

kwargs = {"desc": f"Evaluating {model_name}", "total": len(df), "unit": "dataset"}

# The model weights do not depend on the dataset, so the predictor is built
# once (on the first dataset) and reused. Only `prediction_length` changes
# per dataset, and `ChronosPredictor.predict` reads it at call time.
predictor = None

for _, row in tqdm(df.iterrows(), **kwargs):
    dataset = Dataset(name=row["name"], term=row["term"], verbose=False)

    if predictor is None:
        predictor = ChronosPredictor(
            model_path="amazon/chronos-bolt-base",
            num_samples=20,
            prediction_length=dataset.prediction_length,
        )
    else:
        predictor.prediction_length = dataset.prediction_length

    res = evaluate_model(
        predictor,
        test_data=dataset.test_data,
        metrics=metrics,
        batch_size=1024,
        axis=None,
        mask_invalid_label=True,
        allow_nan_forecast=False,
        seasonality=dataset.seasonality,
    )

    # Append and close after every dataset so finished results survive an
    # interrupted run.
    with open(csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(
            [dataset.config, model_name]
            + [res[key][0] for key in metric_keys]
            + [row["domain"], row["num_variates"]]
        )

Evaluating chronos_bolt_base:   0%|          | 0/97 [00:00<?, ?dataset/s]

prediction_length: 48


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
7it [04:38, 39.80s/it]
6460it [01:50, 58.66it/s]
Evaluating chronos_bolt_base:   1%|          | 1/97 [06:37<10:35:53, 397.44s/dataset]

prediction_length: 30


1it [00:04,  4.65s/it]
646it [00:00, 1142.16it/s]
Evaluating chronos_bolt_base:   2%|▏         | 2/97 [06:43<4:24:33, 167.09s/dataset] 

prediction_length: 48


6it [04:22, 43.83s/it]
6137it [00:12, 511.19it/s]
Evaluating chronos_bolt_base:   3%|▎         | 3/97 [11:19<5:39:43, 216.84s/dataset]

prediction_length: 30


1it [00:01,  1.28s/it]
90it [00:00, 965.05it/s]
Evaluating chronos_bolt_base:   4%|▍         | 4/97 [11:21<3:24:34, 131.98s/dataset]

prediction_length: 48


1it [00:26, 26.35s/it]
600it [00:01, 321.36it/s]
Evaluating chronos_bolt_base:   5%|▌         | 5/97 [11:49<2:25:17, 94.76s/dataset] 

prediction_length: 48


2it [00:46, 23.36s/it]
1092it [00:01, 843.99it/s]
Evaluating chronos_bolt_base:   6%|▌         | 6/97 [12:38<1:59:55, 79.07s/dataset]

prediction_length: 48


1it [00:04,  4.81s/it]
312it [00:00, 1042.21it/s]
Evaluating chronos_bolt_base:   7%|▋         | 7/97 [12:44<1:22:36, 55.07s/dataset]

prediction_length: 48


44it [32:05, 43.75s/it]
45000it [01:26, 521.16it/s]
Evaluating chronos_bolt_base:   8%|▊         | 8/97 [46:19<16:47:33, 679.25s/dataset]

prediction_length: 48


5it [01:12, 14.58s/it]
5000it [00:04, 1082.74it/s]
Evaluating chronos_bolt_base:   9%|▉         | 9/97 [47:38<12:00:54, 491.53s/dataset]

prediction_length: 48


18it [13:09, 43.84s/it]
18000it [00:34, 515.78it/s]
Evaluating chronos_bolt_base:  10%|█         | 10/97 [1:01:24<14:22:29, 594.82s/dataset]

prediction_length: 48


2it [00:30, 15.29s/it]
2000it [00:01, 1063.34it/s]
Evaluating chronos_bolt_base:  11%|█▏        | 11/97 [1:01:58<10:06:15, 422.97s/dataset]

prediction_length: 60


1it [00:01,  1.27s/it]
30it [00:00, 457.83it/s]
Evaluating chronos_bolt_base:  12%|█▏        | 12/97 [1:01:59<6:57:42, 294.86s/dataset] 

prediction_length: 48


1it [00:06,  6.69s/it]
140it [00:00, 195.59it/s]
Evaluating chronos_bolt_base:  13%|█▎        | 13/97 [1:02:07<4:51:07, 207.94s/dataset]

prediction_length: 48


1it [00:01,  1.77s/it]
42it [00:00, 679.26it/s]
Evaluating chronos_bolt_base:  14%|█▍        | 14/97 [1:02:10<3:21:44, 145.84s/dataset]

prediction_length: 60


1it [00:29, 29.31s/it]
630it [00:01, 503.85it/s]
Evaluating chronos_bolt_base:  15%|█▌        | 15/97 [1:02:41<2:32:03, 111.26s/dataset]

prediction_length: 12


3it [00:04,  1.44s/it]
2674it [00:02, 963.94it/s]
Evaluating chronos_bolt_base:  16%|█▋        | 16/97 [1:02:49<1:48:19, 80.24s/dataset] 

prediction_length: 30


1it [00:01,  1.14s/it]
266it [00:00, 915.24it/s]
Evaluating chronos_bolt_base:  18%|█▊        | 17/97 [1:02:51<1:15:37, 56.71s/dataset]

prediction_length: 48


8it [05:37, 42.15s/it]
7400it [02:27, 50.21it/s]
Evaluating chronos_bolt_base:  19%|█▊        | 18/97 [1:10:57<4:04:25, 185.64s/dataset]

prediction_length: 30


2it [00:57, 28.52s/it]
1850it [00:01, 990.62it/s]
Evaluating chronos_bolt_base:  20%|█▉        | 19/97 [1:11:57<3:12:08, 147.80s/dataset]

prediction_length: 48


8it [05:25, 40.64s/it]
7400it [00:40, 180.61it/s]
Evaluating chronos_bolt_base:  21%|██        | 20/97 [1:18:04<4:34:12, 213.67s/dataset]

prediction_length: 8


2it [00:04,  2.38s/it]
1110it [00:00, 1204.73it/s]
Evaluating chronos_bolt_base:  22%|██▏       | 21/97 [1:18:10<3:11:49, 151.44s/dataset]

prediction_length: 48


1it [00:05,  5.89s/it]
140it [00:01, 98.80it/s]
Evaluating chronos_bolt_base:  23%|██▎       | 22/97 [1:18:18<2:15:25, 108.34s/dataset]

prediction_length: 30


1it [00:00,  2.79it/s]
21it [00:00, 860.20it/s]
Evaluating chronos_bolt_base:  24%|██▎       | 23/97 [1:18:19<1:33:51, 76.10s/dataset] 

prediction_length: 48


1it [00:05,  5.80s/it]
140it [00:00, 303.83it/s]
Evaluating chronos_bolt_base:  25%|██▍       | 24/97 [1:18:26<1:07:16, 55.30s/dataset]

prediction_length: 8


1it [00:00,  9.14it/s]
14it [00:00, 814.89it/s]
Evaluating chronos_bolt_base:  26%|██▌       | 25/97 [1:18:26<46:40, 38.89s/dataset]  

prediction_length: 48


1it [00:05,  5.92s/it]
140it [00:01, 99.18it/s]
Evaluating chronos_bolt_base:  27%|██▋       | 26/97 [1:18:34<35:02, 29.62s/dataset]

prediction_length: 30


1it [00:00,  2.85it/s]
21it [00:00, 689.06it/s]
Evaluating chronos_bolt_base:  28%|██▊       | 27/97 [1:18:35<24:29, 21.00s/dataset]

prediction_length: 48


1it [00:05,  5.85s/it]
140it [00:00, 317.42it/s]
Evaluating chronos_bolt_base:  29%|██▉       | 28/97 [1:18:42<19:15, 16.75s/dataset]

prediction_length: 8


1it [00:00,  9.56it/s]
14it [00:00, 831.05it/s]
Evaluating chronos_bolt_base:  30%|██▉       | 29/97 [1:18:43<13:29, 11.90s/dataset]

prediction_length: 30


1it [00:32, 32.68s/it]
826it [00:00, 938.91it/s]
Evaluating chronos_bolt_base:  31%|███       | 30/97 [1:19:19<21:26, 19.21s/dataset]

prediction_length: 8


1it [00:02,  2.44s/it]
472it [00:00, 1229.88it/s]
Evaluating chronos_bolt_base:  32%|███▏      | 31/97 [1:19:22<15:54, 14.45s/dataset]

prediction_length: 12


1it [00:01,  1.66s/it]
767it [00:00, 1034.86it/s]
Evaluating chronos_bolt_base:  33%|███▎      | 32/97 [1:19:25<11:57, 11.04s/dataset]

prediction_length: 48


1it [00:18, 18.46s/it]
420it [00:03, 123.27it/s]
Evaluating chronos_bolt_base:  34%|███▍      | 33/97 [1:19:48<15:29, 14.52s/dataset]

prediction_length: 30


1it [00:00,  2.80it/s]
42it [00:00, 953.04it/s]
Evaluating chronos_bolt_base:  35%|███▌      | 34/97 [1:19:49<10:58, 10.45s/dataset]

prediction_length: 48


1it [00:17, 17.26s/it]
399it [00:00, 507.60it/s]
Evaluating chronos_bolt_base:  36%|███▌      | 35/97 [1:20:07<13:19, 12.89s/dataset]

prediction_length: 30


1it [00:04,  4.46s/it]
540it [00:00, 1065.58it/s]
Evaluating chronos_bolt_base:  37%|███▋      | 36/97 [1:20:13<10:51, 10.68s/dataset]

prediction_length: 48


6it [03:57, 39.57s/it]
5400it [00:12, 437.21it/s]
Evaluating chronos_bolt_base:  38%|███▊      | 37/97 [1:24:24<1:22:40, 82.68s/dataset]

prediction_length: 14


5it [03:11, 38.33s/it]
4227it [00:05, 749.74it/s]
Evaluating chronos_bolt_base:  39%|███▉      | 38/97 [1:27:42<1:55:33, 117.52s/dataset]

prediction_length: 48


1it [00:08,  8.93s/it]
414it [00:00, 855.22it/s]
Evaluating chronos_bolt_base:  40%|████      | 39/97 [1:27:53<1:22:28, 85.31s/dataset] 

prediction_length: 18


47it [15:13, 19.43s/it]
48000it [00:47, 1000.12it/s]
Evaluating chronos_bolt_base:  41%|████      | 40/97 [1:44:06<5:34:02, 351.62s/dataset]

prediction_length: 8


24it [03:54,  9.77s/it]
24000it [00:23, 1036.01it/s]
Evaluating chronos_bolt_base:  42%|████▏     | 41/97 [1:48:29<5:03:35, 325.28s/dataset]

prediction_length: 13


1it [00:17, 17.34s/it]
359it [00:00, 795.86it/s]
Evaluating chronos_bolt_base:  43%|████▎     | 42/97 [1:48:48<3:33:47, 233.23s/dataset]

prediction_length: 6


23it [01:41,  4.42s/it]
22974it [00:22, 1031.55it/s]
Evaluating chronos_bolt_base:  44%|████▍     | 43/97 [1:50:57<3:01:56, 202.16s/dataset]

prediction_length: 30


1it [00:08,  8.09s/it]
807it [00:00, 935.57it/s]
Evaluating chronos_bolt_base:  45%|████▌     | 44/97 [1:51:07<2:07:33, 144.41s/dataset]

prediction_length: 30


1it [00:00,  1.19it/s]
20it [00:00, 229.33it/s]
Evaluating chronos_bolt_base:  46%|████▋     | 45/97 [1:51:09<1:27:58, 101.51s/dataset]

prediction_length: 12


1it [00:00,  5.56it/s]
7it [00:00, 585.06it/s]
Evaluating chronos_bolt_base:  47%|████▋     | 46/97 [1:51:09<1:00:35, 71.28s/dataset] 

prediction_length: 8


1it [00:00,  1.17it/s]
20it [00:00, 633.24it/s]
Evaluating chronos_bolt_base:  48%|████▊     | 47/97 [1:51:11<41:55, 50.32s/dataset]  

prediction_length: 48


3it [02:01, 40.35s/it]
2740it [00:21, 128.33it/s]
Evaluating chronos_bolt_base:  49%|████▉     | 48/97 [1:53:34<1:03:50, 78.18s/dataset]

prediction_length: 30


1it [00:01,  1.87s/it]
274it [00:00, 1074.56it/s]
Evaluating chronos_bolt_base:  51%|█████     | 49/97 [1:53:37<44:25, 55.52s/dataset]  

prediction_length: 48


3it [01:58, 39.43s/it]
2603it [00:05, 514.50it/s]
Evaluating chronos_bolt_base:  52%|█████▏    | 50/97 [1:55:41<59:39, 76.15s/dataset]

prediction_length: 8


1it [00:00,  3.37it/s]
137it [00:00, 967.59it/s]
Evaluating chronos_bolt_base:  53%|█████▎    | 51/97 [1:55:42<41:06, 53.62s/dataset]

prediction_length: 30


94it [25:29, 16.28s/it]
96216it [01:29, 1079.23it/s]
Evaluating chronos_bolt_base:  54%|█████▎    | 52/97 [2:22:54<6:35:28, 527.29s/dataset]

prediction_length: 30


1it [00:00,  1.17it/s]
20it [00:00, 467.95it/s]
Evaluating chronos_bolt_base:  55%|█████▍    | 53/97 [2:22:56<4:30:59, 369.54s/dataset]

prediction_length: 12


1it [00:00, 14.22it/s]
2it [00:00, 250.97it/s]
Evaluating chronos_bolt_base:  56%|█████▌    | 54/97 [2:22:56<3:05:30, 258.84s/dataset]

prediction_length: 8


1it [00:00,  2.93it/s]
14it [00:00, 724.02it/s]
Evaluating chronos_bolt_base:  57%|█████▋    | 55/97 [2:22:57<2:07:01, 181.46s/dataset]

prediction_length: 480


5it [36:27, 437.55s/it]
Evaluating chronos_bolt_base:  57%|█████▋    | 55/97 [2:59:26<2:17:01, 195.75s/dataset]


KeyboardInterrupt: 

## Results

Running the above cell will generate a csv file called `all_results.csv` under the `results/<model_name>/<split_name>` folder (here `results/chronos_bolt_base/train_test`) containing the results for the Chronos model on the gift-eval benchmark. We can display the csv file using the following code:

In [None]:
import pandas as pd

# The evaluation cell writes to ../results/<model_name>/<split_name>/, so the
# split directory must be part of the path when reading the file back.
df = pd.read_csv(f"../results/{model_name}/{split_name}/all_results.csv")
df

<hr>

## Pretraining Datasets

Load the pretraining dataset information CSV file.

In [None]:
from pathlib import Path

# Switch to the pretraining split and load its dataset listing. This rebinds
# `info_path` and `df`, which earlier cells also assign.
split = "pretrain"
info_path = Path("resources", split, "info.csv")
df = pd.read_csv(info_path)

print(f"Reading {len(df)} {split} datasets...")
df.head()

Evaluate the model on each name-term combination in the pretraining split.

In [None]:
# Results go to ../results/<model_name>/<split>/all_results.csv.
output_dir = Path("..") / "results" / model_name / split
output_dir.mkdir(parents=True, exist_ok=True)

csv_file_path = output_dir / "all_results.csv"

# Keys under which the evaluation result exposes each metric. The CSV header
# and every data row are derived from this one list so they cannot drift
# apart.
metric_keys = [
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
]

# Start a fresh results file containing only the header row.
with open(csv_file_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)

    writer.writerow(
        ["dataset", "model"]
        + [f"eval_metrics/{key}" for key in metric_keys]
        + ["domain", "num_variates"]
    )

kwargs = {
    "desc": f"Evaluating {model_name}",
    "total": len(df),
    "unit": "dataset",
}

# The model weights do not depend on the dataset, so the predictor is built
# once (on the first dataset) and reused. Only `prediction_length` changes
# per dataset, and `ChronosPredictor.predict` reads it at call time.
predictor = None

for _, row in tqdm(df.iterrows(), **kwargs):
    dataset = Dataset(name=row["name"], term=row["term"])

    if predictor is None:
        predictor = ChronosPredictor(
            model_path="amazon/chronos-bolt-base",
            num_samples=20,
            prediction_length=dataset.prediction_length,
        )
    else:
        predictor.prediction_length = dataset.prediction_length

    res = evaluate_model(
        predictor,
        test_data=dataset.test_data,
        metrics=metrics,
        batch_size=1024,
        axis=None,
        mask_invalid_label=True,
        allow_nan_forecast=False,
        seasonality=dataset.seasonality,
    )

    # Append and close after every dataset so finished results survive an
    # interrupted run.
    with open(csv_file_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(
            [dataset.config, model_name]
            + [res[key][0] for key in metric_keys]
            # The header declares `domain` and `num_variates`, so every row
            # must carry those two cells. `Series.get` leaves them blank if
            # the pretraining info file does not provide the columns.
            + [row.get("domain", ""), row.get("num_variates", "")]
        )