diff --git a/pkg/R/forecast.hybridModel.R b/pkg/R/forecast.hybridModel.R index 2974481..9c75736 100644 --- a/pkg/R/forecast.hybridModel.R +++ b/pkg/R/forecast.hybridModel.R @@ -108,7 +108,7 @@ forecast.hybridModel <- function(object, if("auto.arima" %in% includedModels){ # Only apply the xreg if it was used in the original model xregA <- xreg - if(is.null(object$auto.arima$xreg)){ + if(!object$xreg$auto.arima){ xregA <- NULL } forecasts$auto.arima <- forecast(object$auto.arima, h = h, xreg = xregA, level = level, ...) @@ -125,7 +125,7 @@ forecast.hybridModel <- function(object, if("nnetar" %in% includedModels){ # Only apply the xreg if it was used in the original model xregN <- xreg - if(is.null(object$nnetar$xreg)){ + if(!object$xreg$nnetar){ xregN <- NULL } forecasts$nnetar <- forecast(object$nnetar, h = h, xreg = xregN, PI = PI, level = level, ...) @@ -135,7 +135,7 @@ forecast.hybridModel <- function(object, # Only apply the xreg if it was used in the original model xregS <- xreg # xreg is only used in stlm if method = "arima", and it is stored in slot $model$xreg - if(is.null(object$stlm$model$xreg)){ + if(!object$xreg$stlm){ xregS <- NULL } forecasts$stlm <- forecast(object$stlm, h = h, xreg = xregS, level = level, ...) diff --git a/pkg/R/hybridModel.R b/pkg/R/hybridModel.R index 9b2e04b..364764f 100644 --- a/pkg/R/hybridModel.R +++ b/pkg/R/hybridModel.R @@ -109,9 +109,6 @@ hybridModel <- function(y, models = "aefnst", horizonAverage = FALSE, parallel = FALSE, num.cores = 2L, verbose = TRUE){ - # Weights could be set to equal (the default), based on in-sample errors, or based on cv errors - # errorMethod will determine which type of errors to use for weights. Some choices from accuracy() - # are not appropriate. If weights = "equal", this would be ignored. # The dependent variable must be numeric and not a matrix/dataframe if(!is.numeric(y) || !is.null(dim(y))){ @@ -212,10 +209,7 @@ hybridModel <- function(y, models = "aefnst", if(parallel){ warning("The 'parallel' argument is currently unimplemented. Ignoring for now.") } - # if(weights == "cv.errors"){ - # warning("Cross validated error weights are currently unimplemented. Ignoring for now.") - # weights <- "equal" - # } + if(weights == "cv.errors" && errorMethod == "MASE"){ warning("cv errors currently do not support MASE. Reverting to RMSE.") errorMethod <- "RMSE" @@ -407,10 +401,33 @@ hybridModel <- function(y, models = "aefnst", tsp(fits) <- tsp(resid) <- tsp(y) } + # Save which models used xreg + xregs <- list() + if("a" %in% expandedModels){ + xregs$auto.arima <- FALSE + if("xreg" %in% names(a.args) && !is.null(a.args$xreg)){ + xregs$auto.arima <- TRUE + } + } + if("n" %in% expandedModels){ + xregs$nnetar <- FALSE + if("xreg" %in% names(n.args) && !is.null(n.args$xreg)){ + xregs$nnetar <- TRUE + } + } + if("s" %in% expandedModels){ + xregs$stlm <- FALSE + methodArima <- "method" %in% names(s.args) && s.args$method == "arima" + if("xreg" %in% names(s.args) && !is.null(s.args$xreg) && methodArima){ + xregs$stlm <- TRUE + } + } + # Prepare the hybridModel object class(modelResults) <- "hybridModel" modelResults$frequency <- frequency(y) modelResults$x <- y + modelResults$xreg <- xregs modelResults$models <- includedModels modelResults$fitted <- fits modelResults$residuals <- resid diff --git a/pkg/tests/testthat/test-forecast.hybridModel.R b/pkg/tests/testthat/test-forecast.hybridModel.R index 174b856..e73e942 100644 --- a/pkg/tests/testthat/test-forecast.hybridModel.R +++ b/pkg/tests/testthat/test-forecast.hybridModel.R @@ -11,8 +11,6 @@ if(require(forecast) & require(testthat)){ expect_error(forecast(object = hModel, h = "a")) # h should be an integer expect_error(forecast(object = hModel, h = 3.2)) -#~ expect_error(forecast(object = hModel, h = 5, -#~ xreg = matrix(1:5, nrow = 5, ncol = 2))) # matrix should be numeric expect_error(forecast(object = hModel, h = 5, xreg = matrix("a", nrow = 5, ncol = 2))) @@ -32,14 +30,16 @@ if(require(forecast) & require(testthat)){ inputSeries <- ts(wineind[1:25], f = frequency(wineind)) mm <- matrix(runif(length(inputSeries)), nrow = length(inputSeries)) # stlm only works with xreg when method = "arima" is passed in s.args - expect_error(aa <- hybridModel(inputSeries, models = "afns", a.args = list(xreg = mm), s.args = list(xreg = mm))) + expect_error(aa <- hybridModel(inputSeries, models = "afns", + a.args = list(xreg = mm), + s.args = list(xreg = mm))) aa <- hybridModel(inputSeries, models = "aefnst", a.args = list(xreg = mm), n.args = list (xreg = mm), s.args = list(xreg = mm, method = "arima")) # If xreg is used and no h is provided, overwrite h newXreg <- matrix(rnorm(20), nrow = 20) - expect_error(tmp <- forecast(aa, xreg = newXreg, npaths = 50), NA) + expect_error(tmp <- forecast(aa, xreg = newXreg, npaths = 5), NA) # If nrow(xreg) != h, issue a warning but set h <- nrow(xreg) expect_warning(forecast(aa, h = 10, xreg = newXreg, PI = FALSE)) newXreg <- matrix(rnorm(24), nrow = 24) @@ -49,7 +49,7 @@ if(require(forecast) & require(testthat)){ mod <- hybridModel(inputSeries, models = "af", a.args = list(xreg = mm)) expect_error(forecast(mod, xreg = newXreg), NA) mod <- hybridModel(inputSeries, models = "nf", n.args = list(xreg = mm)) - expect_error(forecast(mod, xreg = newXreg), NA) + expect_error(forecast(mod, xreg = newXreg, PI = FALSE), NA) mod <- hybridModel(inputSeries, models = "sf", s.args = list(xreg = mm, method = "arima")) expect_error(forecast(mod, xreg = newXreg), NA)