In [1]:
%pip install statsforecast
%pip install statsmodels
%pip install pmdarima
%pip install pandas
%pip install plotly
%pip install nbformat>=4.2.0

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [103]:
from dataclasses import dataclass, asdict
from statsforecast.models import ARIMA
from typing import Optional, Callable
from pmdarima.arima import auto_arima
from numpy.typing import NDArray
import plotly.graph_objs as go
import statsmodels.api as sm
from arima import Model
import pandas as pd
import time

In [136]:
@dataclass
class SearchParams:
    start_p: int = 0
    max_p: int = 3
    d: int = 0
    max_d: int = 2
    start_q: int = 0
    max_q: int = 3
    start_P: int = 0
    max_P: int = 1
    start_Q: int = 0
    max_Q: int = 1
    D: int = 0
    max_D: int = 2
    seasonal: bool = True
    m: int = 12
    maxiter: int = 50
    information_criterion: str = "aic"
    suppress_warnings: bool = True
    trace: bool = True

    def count_max_enumerations(self) -> None:
        p = self.max_p - self.start_p + 1
        d = self.max_d - self.d + 1
        q = self.max_q - self.start_q + 1
        P = self.max_P - self.start_P + 1
        D = self.max_D - self.D + 1
        Q = self.max_Q - self.start_Q + 1

        n = p * d * q * P * D * Q
        print(f"Max possible iterations = {n}")

In [44]:
@dataclass
class HyperParams:
    order: tuple[int, int, int]
    seasonal_order: tuple[int, int, int, int]

In [119]:
def timeit(func: Callable) -> Callable:

    def inner(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"Time for {func.__name__}: {round(end - start, 2)} seconds.")
        return result
    
    return inner

In [120]:
@timeit
def my_preds(hyperparams: HyperParams, y_train: NDArray, h: int) -> NDArray:
    m = Model.sarima(order=hyperparams.order, seasonal_order=hyperparams.seasonal_order)
    m.fit(y=y_train)
    return m.predict(h=h, x=None)

In [121]:
@timeit
def sf_preds(hyperparams: HyperParams, y_train: NDArray, h: int) -> NDArray:
    m = Model.sarima(order=hyperparams.order, seasonal_order=hyperparams.seasonal_order)
    m.fit(y=y_train)
    m = ARIMA(order=hyperparams.order,
              seasonal_order=hyperparams.seasonal_order[: -1],
              season_length=hyperparams.seasonal_order[-1]
              )
    return m.forecast(y=y_train, h=h)["mean"]

In [122]:
@timeit
def sm_preds(hyperparams: HyperParams, y_train: NDArray, h: int) -> NDArray:
    m = (sm.tsa.statespace
        .SARIMAX(endog=y_train, order=hyperparams.order, seasonal_order=hyperparams.seasonal_order)
        .fit()
    )
    return m.predict(start=y_train.size, end=y_train.size + h - 1)

In [159]:
class Run:

    def __init__(
            self,
            df: pd.DataFrame,
            ds: str,
            y: str,
            h: int
    ) -> None:
        self.df = df
        self.ds = ds
        self.y = y
        self.h = h

        self.y_train = df[y].values[:-h]
        self.test_dates = df[ds].values[-h:]
        self.preds: dict[str, NDArray] = {}
        self.hyper_params: Optional[HyperParams] = None

    @timeit
    def assign_params(
            self,
            search_params: Optional[SearchParams] = None,
            default_params: Optional[HyperParams] = None
    ) -> None:
        if ((search_params is None and default_params is None)
            or (search_params is not None and default_params is not None)
            ):
            raise ValueError("Must pass exactly one of search_params or default_params.")
        elif search_params is not None:
            self.hyper_params = self._search_hyperparameters(search_params)
        else:
            self.hyper_params = default_params

    def _search_hyperparameters(self, search_params: SearchParams) -> HyperParams:
        search = auto_arima(y=self.y_train, **asdict(search_params))
        params = search.get_params()
        return HyperParams(params["order"], params["seasonal_order"])
    
    def add_preds(self, func: Callable, name: str) -> None:
        self.preds[name] = func(self.hyper_params, self.y_train, self.h)

    def plot(self) -> None:
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=self.df[self.ds], y=self.df[self.y], mode='lines', name=self.y))
        for name, y in self.preds.items():
            fig.add_trace(go.Scatter(x=self.test_dates, y=y, mode='lines', name=name))
        fig.update_layout(showlegend=True)
        fig.show()

# SUNSPOTS

In [152]:
search_params = SearchParams(
    start_p=0,
    max_p=4,
    d=0,
    max_d=1,
    start_q=0,
    max_q=4,
    start_P=0,
    max_P=2,
    start_Q=0,
    max_Q=2,
    D=1,
    seasonal=True,
    m=12,
    maxiter=50,
)
search_params.count_max_enumerations()

Max possible iterations = 900


In [155]:
sunspots = Run(
    df=sm.datasets.sunspots.load_pandas().data,
    ds="YEAR",
    y="SUNACTIVITY",
    h=45
)
sunspots.assign_params(
        # search_params=search_params,
    default_params=HyperParams((5, 0, 0), (2, 1, 0, 43))
)
sunspots.add_preds(my_preds, "MyArima")
sunspots.add_preds(sf_preds, "StatsForecast")
sunspots.add_preds(sm_preds, "StatsModels")
sunspots.plot()

Time for my_preds: 0.0 seconds.
Time for sf_preds: 1.38 seconds.


# AUSSIE BEER PRODUCTION

In [135]:
search_params = SearchParams(
    start_p=0,
    max_p=6,
    d=0,
    max_d=1,
    start_q=0,
    max_q=6,
    start_P=0,
    max_P=2,
    start_Q=0,
    max_Q=2,
    D=2,
    seasonal=True,
    m=12,
    maxiter=50,
)
search_params.count_max_enumerations()

Max iterations = 882


In [158]:
beer = Run(
    df=pd.read_csv("https://raw.githubusercontent.com/ejgao/Time-Series-Datasets/master/monthly-beer-production-in-austr.csv"),
    ds="Month",
    y="Monthly beer production",
    h=12
)
beer.assign_params(
    # search_params=search_params,
    default_params=HyperParams((5, 0, 0), (1, 2, 1, 12))
)
beer.add_preds(my_preds, "MyArima")
beer.add_preds(sf_preds, "StatsForecast")
sunspots.add_preds(sm_preds, "StatsModels")
beer.plot()

Time for my_preds: 0.01 seconds.
Time for sf_preds: 0.96 seconds.
Time for sm_preds: 15.05 seconds.
