In [None]:
#| default_exp distributed.utils

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# Distributed utils

In [None]:
#| hide
from fastcore.test import test_eq

In [None]:
#| export
from typing import Optional

from statsforecast.core import ParallelBackend

In [None]:
#| export
def forecast(
    df,
    models,
    freq,
    h,
    fallback_model=None,
    X_df=None,
    level=None,
    parallel: Optional["ParallelBackend"] = None,
):
    """Forecast `h` steps ahead with `models`, delegating the work to a backend.

    Parameters
    ----------
    df : input data, passed unchanged to the backend (any frame the chosen
        backend accepts — TODO confirm the exact schema per backend).
    models : list of model instances to fit and forecast with.
    freq : frequency string of the series (e.g. ``'D'``).
    h : forecast horizon.
    fallback_model : optional model to use when a model in `models` fails.
    X_df : optional exogenous-variable data, forwarded as ``X_df``.
    level : optional levels, presumably for prediction intervals; forwarded
        as ``level``.
    parallel : backend performing the computation; a fresh
        ``ParallelBackend()`` is created when omitted.

    Returns
    -------
    Whatever ``parallel.forecast`` returns.
    """
    if parallel is None:
        parallel = ParallelBackend()
    # The backend receives `fallback_model` positionally, as its fourth argument.
    return parallel.forecast(
        df,
        models,
        freq,
        fallback_model,
        h=h,
        X_df=X_df,
        level=level,
    )

In [None]:
#| export
def cross_validation(
    df,
    models,
    freq,
    h,
    n_windows=1,
    step_size=1,
    test_size=None,
    input_size=None,
    parallel: Optional["ParallelBackend"] = None,
    fallback_model=None,
    level=None,
):
    """Run cross-validation of `models`, delegating the work to a backend.

    Parameters
    ----------
    df : input data, passed unchanged to the backend.
    models : list of model instances to evaluate.
    freq : frequency string of the series (e.g. ``'D'``).
    h : forecast horizon of each window.
    n_windows : number of cross-validation windows.
    step_size : step between consecutive windows.
    test_size : optional test size, forwarded as-is.
    input_size : optional input size, forwarded as-is.
    parallel : backend performing the computation; a fresh
        ``ParallelBackend()`` is created when omitted.
    fallback_model : optional model to use when a model in `models` fails
        (mirrors the same parameter of `forecast`).
    level : optional levels, presumably for prediction intervals (mirrors
        `forecast`).

    Returns
    -------
    Whatever ``backend.cross_validation`` returns.
    """
    backend = parallel if parallel is not None else ParallelBackend()
    kwargs = dict(
        h=h,
        n_windows=n_windows,
        step_size=step_size,
        test_size=test_size,
        input_size=input_size,
    )
    # Forward the optional extras only when explicitly given, so calls that do
    # not use them reach the backend exactly as they did before these existed.
    if fallback_model is not None:
        kwargs["fallback_model"] = fallback_model
    if level is not None:
        kwargs["level"] = level
    return backend.cross_validation(df, models, freq, **kwargs)

In [None]:
#| hide
#| eval: false
from statsforecast.core import StatsForecast
from statsforecast.distributed.fugue import FugueBackend
from statsforecast.models import Naive
from statsforecast.utils import generate_series

In [None]:
#| hide
#| eval: false
df = generate_series(10).reset_index()
df['unique_id'] = df['unique_id'].astype(str)

backend = FugueBackend()

# forecast: the distributed result must equal the local StatsForecast result
fcst_fugue = forecast(df, models=[Naive()], freq='D', h=12, parallel=backend)
fcst_stats = StatsForecast(models=[Naive()], freq='D').forecast(df=df, h=12)
test_eq(fcst_fugue, fcst_stats.reset_index())

# cross validation (separate names so the forecast reference above is kept)
cv_fugue = cross_validation(df, models=[Naive()], freq='D', h=12, parallel=backend)
cv_stats = StatsForecast(models=[Naive()], freq='D').cross_validation(df=df, h=12)
test_eq(cv_fugue, cv_stats.reset_index())

# fallback model: FailNaive.forecast takes no arguments, so any call with
# arguments raises TypeError and the fallback (Naive) has to be used instead.
# Its repr is 'Naive', presumably so the output column is named like the
# reference Naive forecast.
class FailNaive:
    def forecast(self):
        pass

    def __repr__(self):
        return 'Naive'

# directly through the backend
fcst = backend.forecast(df, models=[FailNaive()], freq='D', fallback_model=Naive(), h=12)
test_eq(fcst, fcst_stats.reset_index())
# through the `forecast` wrapper, which forwards `fallback_model` positionally
fcst = forecast(df, models=[FailNaive()], freq='D', h=12, fallback_model=Naive(), parallel=backend)
test_eq(fcst, fcst_stats.reset_index())