In [None]:
#| default_exp experiments.merlion.inference

In [None]:
#| export

from pathlib import Path
from typing import List
import pandas as pd

from merlion.utils.time_series import TimeSeries


from national.experiments.merlion import model
from national.experiments.merlion import data_splits
from national.experiments.merlion import train
from national.util.constants import ARIMA_MAX_STEPS, SMES_MAX_STEPS

In [None]:
#| export
def _forecast(
    data: data_splits.Data,
    models: List[model.Model],
):
    """Forecast every non-empty split with every model, storing results in place.

    For each of the ``test``, ``val`` and ``future`` splits (in that order),
    the split's time stamps are read from ``data.time_stamps``. When they are
    non-empty, each model's wrapped forecaster (``.model``) is asked to
    forecast those time stamps with ``data.train`` as the preceding series.
    The returned forecast and standard error are written to
    ``<model>.forecast.<split>`` and ``<model>.stderr.<split>``.

    Splits with no time stamps are skipped, so the corresponding
    ``forecast`` / ``stderr`` attributes are left untouched for them.

    Args:
        data: Object exposing ``train`` and ``time_stamps.{test,val,future}``.
        models: Model wrappers exposing ``model.forecast(...)`` plus
            ``forecast`` and ``stderr`` attribute holders.

    Returns:
        None. The elements of ``models`` are mutated.
    """
    # Split-major order: every model is run on one split before moving to the
    # next split.
    for split in ("test", "val", "future"):
        time_stamps = getattr(data.time_stamps, split)
        if len(time_stamps) == 0:
            continue

        # Named `_model` (as in `Inference.save_models`) so the loop variable
        # does not shadow the imported `model` module.
        for _model in models:
            forecast, stderr = _model.model.forecast(
                time_stamps=time_stamps,
                time_series_prev=data.train,
            )

            setattr(_model.forecast, split, forecast)
            setattr(_model.stderr, split, stderr)

In [None]:
#| export


class Inference(train.Train):
    """`train.Train` subclass that forecasts all splits as soon as it is built.

    After the parent constructor has run, the instance:

    * stores ``horizon`` as the smaller of ``SMES_MAX_STEPS`` and
      ``ARIMA_MAX_STEPS``;
    * truncates ``data.test`` to ``horizon`` entries (``data.sub_test``);
    * records the time stamps of the train / val / truncated-test / future
      series on ``data.time_stamps``;
    * runs `_forecast` on the Prophet, ARIMA and MSES models (plus the
      selector when ``self.include_selector`` is truthy).
    """

    def __init__(
        self,
        kpi: str,
        freq: str,
        df: pd.DataFrame,
        test_frac: float = 0.15,
        val_frac: float = 0.15,
        **kwargs,
    ):
        """Build the parent `Train` object, then forecast every split.

        Args:
            kpi: Forwarded to `train.Train` (presumably the target column
                name — confirm against `Train`).
            freq: Forwarded to `train.Train`; also kept on ``self.freq``.
            df: Input frame, forwarded to `train.Train`.
            test_frac: Forwarded to `train.Train`.
            val_frac: Forwarded to `train.Train`.
            **kwargs: Any further `train.Train` options, passed through
                unchanged.
        """
        super().__init__(
            kpi=kpi,
            df=df,
            freq=freq,
            test_frac=test_frac,
            val_frac=val_frac,
            **kwargs,
        )

        # The notebook's `inf.freq` check raised AttributeError, i.e. the
        # parent does not expose the frequency; keep it here so callers can
        # read it back.
        self.freq = freq

        # Cap the test window at the shorter of the two step limits.
        self.horizon = min(SMES_MAX_STEPS, ARIMA_MAX_STEPS)

        sub_test_data = self.data.test[:self.horizon]
        self.data.sub_test = sub_test_data

        def _time_stamps_of(series):
            # Time stamps of the first-named univariate in `series`.
            return series.univariates[series.names[0]].time_stamps

        self.data.time_stamps.val = _time_stamps_of(self.data.val)
        # Test time stamps come from the truncated series, not the full split.
        self.data.time_stamps.test = _time_stamps_of(sub_test_data)
        self.data.time_stamps.train = _time_stamps_of(self.data.train)
        self.data.time_stamps.future = _time_stamps_of(self.data.future)

        self._models = [
            self.models.prophet,
            self.models.arima,
            self.models.mses,
            # self.models.ensemble,
            # self.models.partial_ensemble,
        ]

        if self.include_selector:
            self._models = self._models + [self.models.selector]

        print('Start forecast')
        _forecast(
            data=self.data,
            models=self._models,
        )

    def save_models(self):
        """Save each model in ``self._models`` to its own ``model_path``."""
        _model: model.Model
        for _model in self._models:
            _model.model.save(_model.model_path)

In [None]:
# Smoke run: monthly ("M") price data for product 'Haba' in state 'Sinaloa'.
# Constructing Inference calls _forecast from __init__, so this cell also
# produces the forecasts inspected in the cells below.
from national.data_preprocessing.date_features import Data
query = "product=='Haba' & state=='Sinaloa'"
test_frac = 0.10
val_frac = 0.15
nal_df = Data().df.query(query)
inf = Inference(
    df=nal_df,
    kpi="price",
    freq="M",
    load_models=False,  # reaches Train via **kwargs; presumably fits fresh models rather than loading saved ones — TODO confirm
    periods=3,  # reaches Train via **kwargs; presumably the number of future periods — TODO confirm
    test_frac=test_frac,
    val_frac=val_frac,
)

Arima


15:24:28 - cmdstanpy - INFO - Chain [1] start processing
15:24:29 - cmdstanpy - INFO - Chain [1] done processing


Prophet
MSES
Start forecast


In [None]:
# NOTE(review): the recorded output for this cell is an AttributeError —
# `freq` was not an attribute of Inference when it was run.
inf.freq

AttributeError: 'Inference' object has no attribute 'freq'

In [None]:
# Peek at the `future` series held on the data object.
inf.data.future

            price
time             
2021-01-31    0.0
2021-02-28    0.0

In [None]:
# Time stamps of the 'price' univariate in the training split (the recorded
# output shows floats such as 1264896000.0 — presumably Unix epoch seconds).
# NOTE(review): this displays the entire list; slicing it would keep the
# notebook output short.
inf.data.train.univariates['price'].time_stamps

[1264896000.0,
 1267315200.0,
 1269993600.0,
 1272585600.0,
 1275264000.0,
 1277856000.0,
 1280534400.0,
 1283212800.0,
 1285804800.0,
 1288483200.0,
 1291075200.0,
 1293753600.0,
 1296432000.0,
 1298851200.0,
 1301529600.0,
 1304121600.0,
 1306800000.0,
 1309392000.0,
 1312070400.0,
 1314748800.0,
 1317340800.0,
 1320019200.0,
 1322611200.0,
 1325289600.0,
 1327968000.0,
 1330473600.0,
 1333152000.0,
 1335744000.0,
 1338422400.0,
 1341014400.0,
 1343692800.0,
 1346371200.0,
 1348963200.0,
 1351641600.0,
 1354233600.0,
 1356912000.0,
 1359590400.0,
 1362009600.0,
 1364688000.0,
 1367280000.0,
 1369958400.0,
 1372550400.0,
 1375228800.0,
 1377907200.0,
 1380499200.0,
 1383177600.0,
 1385769600.0,
 1388448000.0,
 1391126400.0,
 1393545600.0,
 1396224000.0,
 1398816000.0,
 1401494400.0,
 1404086400.0,
 1406764800.0,
 1409443200.0,
 1412035200.0,
 1414713600.0,
 1417305600.0,
 1419984000.0,
 1422662400.0,
 1425081600.0,
 1427760000.0,
 1430352000.0,
 1433030400.0,
 1435622400.0,
 143830080

In [None]:
# ARIMA forecast for the future time stamps, as stored by _forecast.
inf.models.arima.forecast.future

                price
time                 
2021-01-31  44.766750
2021-02-28  44.881279

In [None]:
# ARIMA forecast for the truncated test split.
# NOTE(review): no output was recorded for this cell; _forecast only assigns
# forecast.test when data.time_stamps.test is non-empty — confirm whether the
# test split was empty for this run.
inf.models.arima.forecast.test

In [None]:
# Smoke run: same product/state filter as the monthly run, at a "4W" frequency.
from national.data_preprocessing.date_features import Data
# NOTE(review): `merlion_inference` is not used anywhere in this cell.
from national.experiments.merlion import inference as merlion_inference
query = "product=='Haba' & state=='Sinaloa'"
test_frac = 0.15
val_frac = 0.15
nal_df = Data().df.query(query)
# Rebinds `inf`; the instance from the monthly run is no longer reachable by name.
inf = Inference(
    df=nal_df,
    kpi="price",
    freq="4W",
    periods=2,  # reaches Train via **kwargs — TODO confirm meaning
    load_models=False,  # reaches Train via **kwargs — TODO confirm meaning
    test_frac=test_frac,
    val_frac=val_frac,
)

Arima


15:25:20 - cmdstanpy - INFO - Chain [1] start processing
15:25:20 - cmdstanpy - INFO - Chain [1] done processing


Prophet
MSES
Start forecast


In [None]:
# Smoke run: product 'Haba' with no state filter, at a "4W" frequency.
# `periods` is not passed here, so whatever default Train applies is used.
from national.data_preprocessing.date_features import Data
query = "product=='Haba'"
test_frac = 0.15
val_frac = 0.15
nal_df = Data().df.query(query)
# Rebinds `inf` again, replacing the instance from the previous run.
inf = Inference(
    df=nal_df,
    kpi="price",
    freq="4W",
    load_models=False,  # reaches Train via **kwargs — TODO confirm meaning
    test_frac=test_frac,
    val_frac=val_frac,
)

Arima


15:25:28 - cmdstanpy - INFO - Chain [1] start processing


Prophet


15:25:29 - cmdstanpy - INFO - Chain [1] done processing


MSES
Start forecast
