# Simple AutoML for time series with Ray AIR

In [1]:
!pip install statsforecast



In [14]:
from typing import Any, List, Union, Callable, Dict, Type, Tuple, Optional
import time
import itertools
import pandas as pd
import numpy as np
from collections import defaultdict
from statsforecast import StatsForecast
from statsforecast.models import ETS, AutoARIMA
from sklearn.metrics import mean_squared_error, mean_absolute_error

import statsforecast_trainer

import ray
from ray import air, tune

In [16]:
# Start from a clean Ray session so this cell can be re-run.
if ray.is_initialized():
    ray.shutdown()

# Every worker gets statsforecast installed and a copy of the current
# directory (which holds statsforecast_trainer.py).
ray.init(
    runtime_env={
        "working_dir": ".",
        "pip": ["statsforecast"],
    }
)

2022-10-24 16:46:49,102	INFO worker.py:1229 -- Using address localhost:9031 set in the environment variable RAY_ADDRESS
2022-10-24 16:46:50,058	INFO worker.py:1341 -- Connecting to existing Ray cluster at address: 172.31.107.241:9031...
2022-10-24 16:46:50,063	INFO worker.py:1518 -- Connected to Ray cluster. View the dashboard at [1m[32mhttps://console.anyscale-staging.com/api/v2/sessions/ses_J737nkSmzssxqHpQ2RQREEaP/services?redirect_to=dashboard [39m[22m
2022-10-24 16:46:50,645	INFO packaging.py:527 -- Creating a file package for local directory '.'.
2022-10-24 16:46:50,650	INFO packaging.py:354 -- Pushing file package 'gcs://_ray_pkg_16da713b42dcb8fa.zip' (1.16MiB) to Ray cluster...
2022-10-24 16:46:50,660	INFO packaging.py:367 -- Successfully pushed file package 'gcs://_ray_pkg_16da713b42dcb8fa.zip'.


0,1
Python version:,3.8.13
Ray version:,3.0.0.dev0
Dashboard:,http://console.anyscale-staging.com/api/v2/sessions/ses_J737nkSmzssxqHpQ2RQREEaP/services?redirect_to=dashboard


## Read the dataset from S3 using `ray.data`

In [17]:
from pyarrow.dataset import field

def transform_ds(batch: pd.DataFrame, y_offset: float = 10) -> pd.DataFrame:
    """Convert a batch of M5 rows into the format StatsForecast expects.

    StatsForecast expects the columns ``unique_id``, ``ds`` and ``y``, so the
    source columns are renamed accordingly. Rows with a missing value in any
    column (including timestamps that parse to NaT) are dropped *before*
    ``unique_id`` is cast to ``str``. Casting first would turn a missing id
    into the literal string ``"None"``/``"nan"`` and let it slip past
    ``dropna``.

    Args:
        batch: A pandas batch with ``item_id``, ``timestamp`` and ``demand``
            columns.
        y_offset: Constant added to every target value (default 10).
            NOTE(review): presumably used to keep the series strictly
            positive, e.g. for multiplicative ETS components; confirm.

    Returns:
        A new DataFrame with ``unique_id`` (str), ``ds`` (datetime) and the
        shifted target ``y``. The input batch is not modified.
    """
    renamed = batch.rename(
        columns={"item_id": "unique_id", "timestamp": "ds", "demand": "y"}
    )
    # Parse timestamps first so unparseable-to-NaT values are caught by dropna.
    cleaned = renamed.assign(ds=pd.to_datetime(renamed["ds"])).dropna()
    return cleaned.assign(
        unique_id=cleaned["unique_id"].astype(str),
        y=cleaned["y"] + y_offset,
    )

# Keep the example small: only one M5 item (time series) is used. The pyarrow
# filter expression restricts the read to rows whose item_id is listed here.
partition_ids = ["FOODS_1_001_CA_1"]

ds = (
    ray.data.read_parquet(
        "s3://anonymous@m5-benchmarks/data/train/target.parquet",
        filter=field("item_id").isin(partition_ids),
        columns=["item_id", "timestamp", "demand"],
    )
    .map_batches(transform_ds, batch_format="pandas")
)

Parquet Files Sample:   0%|          | 0/1 [00:00<?, ?it/s]
Parquet Files Sample: 100%|██████████| 1/1 [00:04<00:00,  4.85s/it]tasks pid=22051) 
Read->Map_Batches: 100%|██████████| 1/1 [00:09<00:00,  9.76s/it]


## Create a `statsforecast` AIR Trainer

By subclassing `BaseTrainer`, we keep the benefits of AIR (such as preprocessors) while writing our own training loop. That loop performs cross-validation and reports the resulting metrics back to Ray.

```{literalinclude} statsforecast_trainer.py
    :language: python
    :start-after: __statsforecast_trainer_start__
    :end-before: __statsforecast_trainer_end__
```

In [21]:
from statsforecast_trainer import StatsforecastTrainer

# NOTE(review): this rebinds the name `statsforecast_trainer`, which until now
# referred to the module imported in the first code cell. Later cells use the
# trainer instance under this name.
statsforecast_trainer = StatsforecastTrainer(datasets=dict(train=ds))

## Define the search space

We use Ray Tune's Optuna integration (`OptunaSearch`) to define a conditional search space with a define-by-run function. The function first samples the model class, then samples only the parameters that apply to the chosen model.

In [22]:
from ray.tune.search.optuna import OptunaSearch

def optuna_search_space(trial) -> Optional[Dict[str, Any]]:
    """Define-by-run search space for ``OptunaSearch``.

    The model class is sampled first. After that, only the hyperparameters
    registered for that class are sampled, which makes the space conditional.

    Args:
        trial: Optuna trial object; values are drawn with
            ``trial.suggest_categorical``.

    Returns:
        A dict of constant parameters that end up in every trial's config
        alongside the sampled values.
    """
    params_by_model = {
        AutoARIMA: {},
        ETS: {"season_length": [6, 7], "model": ["ZNA", "ZZZ"]},
    }

    chosen_model = trial.suggest_categorical("model_cls", [*params_by_model])

    # Only the hyperparameters of the chosen model class are sampled.
    for param_name, choices in params_by_model[chosen_model].items():
        trial.suggest_categorical(param_name, choices)

    # Constant params, identical for every trial.
    return dict(n_splits=5, test_size=1, parallelize_cv=True, freq="D")

# Optuna samples configs from the define-by-run space and minimises the
# `mse_mean` metric reported by the trainer.
algo = OptunaSearch(
    space=optuna_search_space,
    metric="mse_mean",
    mode="min",
)

[I 2022-10-24 16:47:28,289] A new study created in memory with name: optuna


## Create a Tuner and run the Optuna search

In [23]:
# Run five trials of the trainer, each with a config sampled by OptunaSearch.
# Trials are ranked by `mse_mean`, where lower is better.
tuner = tune.Tuner(
    statsforecast_trainer,
    tune_config=tune.TuneConfig(
        search_alg=algo,
        num_samples=5,
        metric="mse_mean",
        mode="min",
    ),
)
result_grid = tuner.fit()

0,1
Current time:,2022-10-24 16:48:10
Running for:,00:00:41.31
Memory:,4.3/62.0 GiB

Trial name,status,loc,freq,model,model_cls,n_splits,parallelize_cv,season_length,test_size,iter,total time (s),mse_mean,mse_std,mae_mean
StatsforecastTrainer_3821d39e,TERMINATED,172.31.107.241:22433,D,ZZZ,<class 'statsfo_3750,5,True,6.0,1,1,38.3345,0.655225,0.335512,0.737604
StatsforecastTrainer_39e294de,TERMINATED,172.31.107.241:22474,D,,<class 'statsfo_f1b0,5,True,,1,1,16.6465,0.658737,0.336924,0.739211
StatsforecastTrainer_39e469f8,TERMINATED,172.31.107.241:22476,D,,<class 'statsfo_f1b0,5,True,,1,1,16.0207,0.658737,0.336924,0.739211
StatsforecastTrainer_39e5f9c6,TERMINATED,172.31.107.241:22478,D,,<class 'statsfo_f1b0,5,True,,1,1,15.5794,0.658737,0.336924,0.739211
StatsforecastTrainer_39e78dea,TERMINATED,172.31.107.241:22483,D,,<class 'statsfo_f1b0,5,True,,1,1,16.3824,0.658737,0.336924,0.739211


Trial name,cutoff_values,cv_time,date,done,episodes_total,experiment_id,experiment_tag,hostname,iterations_since_restore,mae_mean,mae_std,mse_mean,mse_std,node_ip,pid,should_checkpoint,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,unique_ids,warmup_time
StatsforecastTrainer_3821d39e,['2016-05-17T00:00:00.000000000' '2016-05-18T00:00:00.000000000'  '2016-05-19T00:00:00.000000000' '2016-05-20T00:00:00.000000000'  '2016-05-21T00:00:00.000000000'],38.3133,2022-10-24_16-48-09,True,,a79c6935dd35427d856e6bc7cfc155d2,"1_freq=D,model=ZZZ,model_cls=class_statsforecast_models_ETS,n_splits=5,parallelize_cv=True,season_length=6,test_size=1",ip-172-31-107-241,1,0.737604,0.333415,0.655225,0.335512,172.31.107.241,22433,True,38.3345,38.3345,38.3345,1666655289,0,,1,3821d39e,['FOODS_1_001_CA_1'],0.0931134
StatsforecastTrainer_39e294de,['2016-05-17T00:00:00.000000000' '2016-05-18T00:00:00.000000000'  '2016-05-19T00:00:00.000000000' '2016-05-20T00:00:00.000000000'  '2016-05-21T00:00:00.000000000'],16.6249,2022-10-24_16-47-51,True,,957f1ce6f4274ccea883f2b0645b863f,"2_freq=D,model_cls=class_statsforecast_models_AutoARIMA,n_splits=5,parallelize_cv=True,test_size=1",ip-172-31-107-241,1,0.739211,0.335117,0.658737,0.336924,172.31.107.241,22474,True,16.6465,16.6465,16.6465,1666655271,0,,1,39e294de,['FOODS_1_001_CA_1'],0.11385
StatsforecastTrainer_39e469f8,['2016-05-17T00:00:00.000000000' '2016-05-18T00:00:00.000000000'  '2016-05-19T00:00:00.000000000' '2016-05-20T00:00:00.000000000'  '2016-05-21T00:00:00.000000000'],15.9995,2022-10-24_16-47-50,True,,c8abe455c40b4ff29e98a1011ad4203b,"3_freq=D,model_cls=class_statsforecast_models_AutoARIMA,n_splits=5,parallelize_cv=True,test_size=1",ip-172-31-107-241,1,0.739211,0.335117,0.658737,0.336924,172.31.107.241,22476,True,16.0207,16.0207,16.0207,1666655270,0,,1,39e469f8,['FOODS_1_001_CA_1'],0.114835
StatsforecastTrainer_39e5f9c6,['2016-05-17T00:00:00.000000000' '2016-05-18T00:00:00.000000000'  '2016-05-19T00:00:00.000000000' '2016-05-20T00:00:00.000000000'  '2016-05-21T00:00:00.000000000'],15.5574,2022-10-24_16-47-50,True,,6ed1201bfb79451a9def4aa21e1e775a,"4_freq=D,model_cls=class_statsforecast_models_AutoARIMA,n_splits=5,parallelize_cv=True,test_size=1",ip-172-31-107-241,1,0.739211,0.335117,0.658737,0.336924,172.31.107.241,22478,True,15.5794,15.5794,15.5794,1666655270,0,,1,39e5f9c6,['FOODS_1_001_CA_1'],0.121955
StatsforecastTrainer_39e78dea,['2016-05-17T00:00:00.000000000' '2016-05-18T00:00:00.000000000'  '2016-05-19T00:00:00.000000000' '2016-05-20T00:00:00.000000000'  '2016-05-21T00:00:00.000000000'],16.3611,2022-10-24_16-47-51,True,,8f00da992fff45ada93ee280313cb25f,"5_freq=D,model_cls=class_statsforecast_models_AutoARIMA,n_splits=5,parallelize_cv=True,test_size=1",ip-172-31-107-241,1,0.739211,0.335117,0.658737,0.336924,172.31.107.241,22483,True,16.3824,16.3824,16.3824,1666655271,0,,1,39e78dea,['FOODS_1_001_CA_1'],0.112233


2022-10-24 16:47:50,569	INFO tensorboardx.py:267 -- Removed the following hyperparameter values when logging to tensorboard: {'model_cls': <class 'statsforecast.models.AutoARIMA'>}
2022-10-24 16:47:50,939	INFO tensorboardx.py:267 -- Removed the following hyperparameter values when logging to tensorboard: {'model_cls': <class 'statsforecast.models.AutoARIMA'>}
2022-10-24 16:47:51,319	INFO tensorboardx.py:267 -- Removed the following hyperparameter values when logging to tensorboard: {'model_cls': <class 'statsforecast.models.AutoARIMA'>}
2022-10-24 16:47:51,576	INFO tensorboardx.py:267 -- Removed the following hyperparameter values when logging to tensorboard: {'model_cls': <class 'statsforecast.models.AutoARIMA'>}
2022-10-24 16:48:10,012	INFO tensorboardx.py:267 -- Removed the following hyperparameter values when logging to tensorboard: {'model_cls': <class 'statsforecast.models.ETS'>}
2022-10-24 16:48:10,129	INFO tune.py:777 -- Total run time: 41.42 seconds (41.30 seconds for the tuni

In [24]:
# With no arguments, the best trial is picked using the Tuner's metric/mode.
best_result = result_grid.get_best_result()

for metric_name in ("mse_mean", "mae_mean"):
    print(f"Best {metric_name}:", best_result.metrics[metric_name])

Best mse_mean: 0.65522546
Best mae_mean: 0.7376043


In [26]:
# Full config of the best trial: the values Optuna sampled plus the constant
# parameters returned by `optuna_search_space`.
best_result.config

{'model_cls': statsforecast.models.ETS,
 'season_length': 6,
 'model': 'ZZZ',
 'n_splits': 5,
 'test_size': 1,
 'parallelize_cv': True,
 'freq': 'D'}

In [27]:
# Metrics reported by the best trial as a DataFrame. There is a single row
# here because the trial ran one training iteration.
best_result.metrics_dataframe

Unnamed: 0,mse_mean,mse_std,mae_mean,mae_std,unique_ids,cutoff_values,cv_time,time_this_iter_s,should_checkpoint,done,...,date,timestamp,time_total_s,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,warmup_time
0,0.655225,0.335512,0.737604,0.333415,['FOODS_1_001_CA_1'],['2016-05-17T00:00:00.000000000' '2016-05-18T0...,38.313303,38.334548,True,False,...,2022-10-24_16-48-09,1666655289,38.334548,22433,ip-172-31-107-241,172.31.107.241,38.334548,0,1,0.093113
