In [None]:
#| default_exp experiments.merlion.train

In [None]:
#| export
import pandas as pd
from national.experiments.merlion import model as merlion_model
from national.experiments.merlion import data_splits

In [None]:
#| export


class Train:
    """Train the Merlion forecasters for one KPI on its training split.

    All work happens in the constructor:

    * the input frame is split via ``data_splits.Data``;
    * the model collection is built via ``merlion_model.Models``;
    * every selected model wrapper whose ``load_model`` flag is falsy is
      trained on ``data.train``.

    The training-period forecast and standard error returned by
    ``model.train`` are stored on each wrapper as ``forecast.train`` and
    ``stderr.train``.

    Args:
        kpi: Name of the KPI to forecast (e.g. ``'price'``); forwarded to
            ``data_splits.Data``.
        freq: Frequency string, presumably a pandas offset alias such as
            ``'W-MON'``. It is forwarded to ``data_splits.Data`` and, as
            ``granularity``, to ``merlion_model.Models``. It also decides
            whether the selector model is trained (see
            ``_is_short_daily_freq``).
        df: Source frame to split.
        load_models: Forwarded to ``merlion_model.Models``.
        val_frac: Validation fraction forwarded to ``data_splits.Data``.
        **kwargs: Extra keyword arguments forwarded to ``data_splits.Data``.

    Attributes:
        data: The ``data_splits.Data`` instance; ``data.train`` is the
            training input.
        kpi: The ``kpi`` argument.
        freq: The ``freq`` argument.
        include_selector: Whether ``models.selector`` is trained.
        models: The ``merlion_model.Models`` collection.
    """

    def __init__(
        self,
        kpi: str,
        freq: str,
        df: pd.DataFrame,
        load_models: bool = False,
        val_frac: float = 0.15,
        **kwargs,
    ):
        self.data = data_splits.Data(
            kpi=kpi,
            df=df,
            freq=freq,
            val_frac=val_frac,
            **kwargs,
        )

        self.kpi = kpi
        self.freq = freq

        self.include_selector = self._is_short_daily_freq(freq)

        self.models = merlion_model.Models(
            granularity=freq,
            load_models=load_models,
        )

        self._train_models(self._models_to_train())

    @staticmethod
    def _is_short_daily_freq(freq: str) -> bool:
        """Return True when ``freq`` is ``'<n>D'`` with an explicit ``n < 7``.

        The whole digit run is parsed, so ``'10D'`` is rejected rather than
        being read as ``1``. Aliases that merely contain a ``'D'`` (e.g.
        ``'W-WED'``) return False instead of raising ``ValueError``. A bare
        ``'D'`` with no count returns False.
        """
        count, unit = freq[:-1], freq[-1:]
        # isdecimal() (not isdigit()) guarantees int() below cannot fail.
        return unit == 'D' and count.isdecimal() and int(count) < 7

    def _models_to_train(self) -> list:
        """Return the model wrappers to fit, in training order."""
        selected = [
            self.models.arima,
            self.models.prophet,
            self.models.mses,
            # models.ensemble and models.partial_ensemble are currently disabled.
        ]
        if self.include_selector:
            selected.append(self.models.selector)
        return selected

    def _train_models(self, models) -> None:
        """Fit every wrapper not flagged ``load_model`` and store its outputs."""
        for wrapper in models:
            if wrapper.load_model:
                # Already provided by the model collection; skip retraining.
                continue
            print(wrapper.name)

            forecast, stderr = wrapper.model.train(self.data.train)
            wrapper.forecast.train = forecast
            wrapper.stderr.train = stderr

In [None]:
from national.data_preprocessing.date_features import Data
from national.experiments.merlion import train

# Quick end-to-end run on a single product/state slice.
query = "product=='Haba' & state=='Sinaloa'"
test_frac = 0.15
val_frac = 0.15  # Train forwards this to data_splits.Data

_df = Data().df.query(query)

_train = train.Train(
    df=_df,
    kpi='price',
    freq='W-MON',
    val_frac=val_frac,
    # test_frac=test_frac,  # currently not passed
)

Arima


06:50:51 - cmdstanpy - INFO - Chain [1] start processing


Prophet


06:50:51 - cmdstanpy - INFO - Chain [1] done processing


MSES


In [None]:
# Inspect the held-out test split of the current trainer.
# NOTE(review): the saved output is an empty frame with only a `date` column,
# and `test_frac` is not passed to `Train` in the cell that builds `_train` —
# confirm whether an empty test split is intended.
_train.data.test

Empty DataFrame
Columns: [date]
Index: []

In [None]:
# Peek at the training timestamps (they appear to be Unix epoch seconds) for
# the trainer's own KPI; show only the first few instead of the whole list.
_train.data.train.univariates[_train.kpi].time_stamps[:5]

[1263168000.0,
 1263772800.0,
 1264377600.0,
 1264982400.0,
 1265587200.0,
 1266796800.0,
 1267401600.0,
 1268006400.0,
 1268611200.0,
 1269216000.0,
 1269820800.0,
 1270425600.0,
 1271030400.0,
 1271635200.0,
 1272240000.0,
 1272844800.0,
 1274659200.0,
 1275264000.0,
 1276473600.0,
 1277683200.0,
 1278288000.0,
 1278892800.0,
 1279497600.0,
 1280102400.0,
 1280707200.0,
 1281312000.0,
 1281916800.0,
 1282521600.0,
 1283126400.0,
 1283731200.0,
 1284940800.0,
 1285545600.0,
 1286150400.0,
 1286755200.0,
 1287360000.0,
 1287964800.0,
 1289174400.0,
 1290384000.0,
 1290988800.0,
 1291593600.0,
 1292803200.0,
 1293408000.0,
 1294012800.0,
 1294617600.0,
 1295222400.0,
 1295827200.0,
 1296432000.0,
 1297036800.0,
 1298246400.0,
 1298851200.0,
 1300060800.0,
 1300665600.0,
 1301270400.0,
 1301875200.0,
 1302480000.0,
 1303084800.0,
 1303689600.0,
 1304294400.0,
 1305504000.0,
 1306108800.0,
 1306713600.0,
 1307318400.0,
 1307923200.0,
 1308528000.0,
 1309132800.0,
 1309737600.0,
 131034240

In [None]:
from national.data_preprocessing.date_features import Data
from national.experiments.merlion import train

# Same run over every state for the product. Keep the result as `_train` so
# the inspection cell below looks at this run rather than the Sinaloa-only
# trainer built earlier (previously the result was displayed and discarded).
query = "product=='Haba'"
test_frac = 0.15
val_frac = 0.15
_df = Data().df.query(query)

_train = train.Train(
    df=_df,
    kpi='price',
    freq='W-MON',
    # test_frac=test_frac,
    val_frac=val_frac,
)
_train

Arima


06:51:11 - cmdstanpy - INFO - Chain [1] start processing


Prophet


06:51:11 - cmdstanpy - INFO - Chain [1] done processing


MSES


<national.experiments.merlion.train.Train at 0x7f7c7037fe80>

In [None]:
# Peek at the training timestamps (they appear to be Unix epoch seconds) for
# the trainer's own KPI; show only the first few instead of the whole list.
_train.data.train.univariates[_train.kpi].time_stamps[:5]

[1263168000.0,
 1263772800.0,
 1264377600.0,
 1264982400.0,
 1265587200.0,
 1266796800.0,
 1267401600.0,
 1268006400.0,
 1268611200.0,
 1269216000.0,
 1269820800.0,
 1270425600.0,
 1271030400.0,
 1271635200.0,
 1272240000.0,
 1272844800.0,
 1274659200.0,
 1275264000.0,
 1276473600.0,
 1277683200.0,
 1278288000.0,
 1278892800.0,
 1279497600.0,
 1280102400.0,
 1280707200.0,
 1281312000.0,
 1281916800.0,
 1282521600.0,
 1283126400.0,
 1283731200.0,
 1284940800.0,
 1285545600.0,
 1286150400.0,
 1286755200.0,
 1287360000.0,
 1287964800.0,
 1289174400.0,
 1290384000.0,
 1290988800.0,
 1291593600.0,
 1292803200.0,
 1293408000.0,
 1294012800.0,
 1294617600.0,
 1295222400.0,
 1295827200.0,
 1296432000.0,
 1297036800.0,
 1298246400.0,
 1298851200.0,
 1300060800.0,
 1300665600.0,
 1301270400.0,
 1301875200.0,
 1302480000.0,
 1303084800.0,
 1303689600.0,
 1304294400.0,
 1305504000.0,
 1306108800.0,
 1306713600.0,
 1307318400.0,
 1307923200.0,
 1308528000.0,
 1309132800.0,
 1309737600.0,
 131034240