Skip to content

Commit

Permalink
transfer adam_forecaster
Browse files Browse the repository at this point in the history
  • Loading branch information
ltsaprounis committed Feb 13, 2024
1 parent d1641dc commit c0b48b2
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 4 deletions.
52 changes: 49 additions & 3 deletions python/smooth/adam_general/sma.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from smooth.adam_general._adam_general import adam_fitter
from smooth.adam_general._adam_general import adam_fitter, adam_forecaster
from smooth.adam_general.adam_profile import adam_profile_creator


Expand Down Expand Up @@ -75,6 +75,52 @@ def creator_sma(order):
backcast=True,
)

return adam_fitted
fitted_args = dict(
matrixVt=mat_Vt,
matrixWt=mat_Wt,
matrixF=mat_F,
vectorG=vec_G,
lags=lags_model_all,
indexLookupTable=index_lookup_table,
profilesRecent=profiles_recent_table,
E=E_type,
T=T_type,
S=S_type,
nNonSeasonal=components_num_ETS,
nSeasonal=components_num_ETS_seasonal,
nArima=order,
nXreg=xreg_number,
constant=constant_required,
vectorYt=y_in_sample,
vectorOt=ot,
backcast=True,
)

return adam_fitted, fitted_args

sma_fitted, fitted_args = creator_sma(order=order)

# need to convert some inputs to the expected dtypes. This is a temporary fix.
fitted_args["lags"] = np.array(fitted_args["lags"], dtype="uint64")
fitted_args["indexLookupTable"] = np.array(
fitted_args["indexLookupTable"], dtype="uint64"
)

sma_forecast = adam_forecaster(
matrixWt=fitted_args["matrixWt"],
matrixF=fitted_args["matrixF"],
lags=fitted_args["lags"],
indexLookupTable=fitted_args["indexLookupTable"],
profilesRecent=sma_fitted["profile"],
E=fitted_args["E"],
T=fitted_args["T"],
S=fitted_args["S"],
nNonSeasonal=fitted_args["nNonSeasonal"],
nSeasonal=fitted_args["nSeasonal"],
nArima=fitted_args["nArima"],
nXreg=fitted_args["nXreg"],
constant=fitted_args["constant"],
horizon=h,
)

return creator_sma(order=order)
return sma_forecast
2 changes: 1 addition & 1 deletion python/smooth/adam_general/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
if __name__ == "__main__":
y = np.arange(0, 100)
results = sma(y, order=5)
print(results["yFitted"])
print(results)
43 changes: 43 additions & 0 deletions src/python_examples/adamGeneral.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,34 @@ py::dict adamFitter(arma::mat &matrixVt, arma::mat const &matrixWt, arma::mat &m
return result;
}

/* # Function produces the point forecasts for the specified model */
arma::vec adamForecaster(arma::mat const &matrixWt, arma::mat const &matrixF,
arma::uvec lags, arma::umat const &indexLookupTable, arma::mat profilesRecent,
char const &E, char const &T, char const &S,
unsigned int const &nNonSeasonal, unsigned int const &nSeasonal,
unsigned int const &nArima, unsigned int const &nXreg, bool const &constant,
unsigned int const &horizon)
{
// unsigned int lagslength = lags.n_rows;
unsigned int nETS = nNonSeasonal + nSeasonal;
unsigned int nComponents = indexLookupTable.n_rows;

arma::vec vecYfor(horizon, arma::fill::zeros);

/* # Fill in the new xt matrix using F. Do the forecasts. */
for (unsigned int i = 0; i < horizon; i = i + 1)
{
vecYfor.row(i) = adamWvalue(profilesRecent(indexLookupTable.col(i)), matrixWt.row(i), E, T, S,
nETS, nNonSeasonal, nSeasonal, nArima, nXreg, nComponents, constant);

profilesRecent(indexLookupTable.col(i)) = adamFvalue(profilesRecent(indexLookupTable.col(i)),
matrixF, E, T, S, nETS, nNonSeasonal, nSeasonal, nArima, nComponents, constant);
}

// return List::create(Named("matVt") = matrixVtnew, Named("yForecast") = vecYfor);
return vecYfor;
}

PYBIND11_MODULE(_adam_general, m)
{
m.doc() = "Adam code"; // module docstring
Expand All @@ -184,4 +212,19 @@ PYBIND11_MODULE(_adam_general, m)
py::arg("vectorYt"),
py::arg("vectorOt"),
py::arg("backcast"));
m.def("adam_forecaster", &adamForecaster, "forecasts the adam model",
py::arg("matrixWt"),
py::arg("matrixF"),
py::arg("lags"),
py::arg("indexLookupTable"),
py::arg("profilesRecent"),
py::arg("E"),
py::arg("T"),
py::arg("S"),
py::arg("nNonSeasonal"),
py::arg("nSeasonal"),
py::arg("nArima"),
py::arg("nXreg"),
py::arg("constant"),
py::arg("horizon"));
}

0 comments on commit c0b48b2

Please sign in to comment.