Skip to content

Commit

Permalink
fix xreg for stlm
Browse files Browse the repository at this point in the history
  • Loading branch information
dashaub committed Jul 26, 2017
1 parent 17f51f8 commit 5639179
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
6 changes: 3 additions & 3 deletions pkg/R/forecast.hybridModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ forecast.hybridModel <- function(object,
if("auto.arima" %in% includedModels){
# Only apply the xreg if it was used in the original model
xregA <- xreg
if(is.null(object$auto.arima$xreg)){
if(!object$xreg$auto.arima){
xregA <- NULL
}
forecasts$auto.arima <- forecast(object$auto.arima, h = h, xreg = xregA, level = level, ...)
Expand All @@ -125,7 +125,7 @@ forecast.hybridModel <- function(object,
if("nnetar" %in% includedModels){
# Only apply the xreg if it was used in the original model
xregN <- xreg
if(is.null(object$nnetar$xreg)){
if(!object$xreg$nnetar){
xregN <- NULL
}
forecasts$nnetar <- forecast(object$nnetar, h = h, xreg = xregN, PI = PI, level = level, ...)
Expand All @@ -135,7 +135,7 @@ forecast.hybridModel <- function(object,
# Only apply the xreg if it was used in the original model
xregS <- xreg
# xreg is only used in stlm if method = "arima", and it is stored in slot $model$xreg
if(is.null(object$stlm$model$xreg)){
if(!object$xreg$stlm){
xregS <- NULL
}
forecasts$stlm <- forecast(object$stlm, h = h, xreg = xregS, level = level, ...)
Expand Down
31 changes: 24 additions & 7 deletions pkg/R/hybridModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ hybridModel <- function(y, models = "aefnst",
horizonAverage = FALSE,
parallel = FALSE, num.cores = 2L,
verbose = TRUE){
# Weights could be set to equal (the default), based on in-sample errors, or based on cv errors
# errorMethod will determine which type of errors to use for weights. Some choices from accuracy()
# are not appropriate. If weights = "equal", this would be ignored.

# The dependent variable must be numeric and not a matrix/dataframe
if(!is.numeric(y) || !is.null(dim(y))){
Expand Down Expand Up @@ -212,10 +209,7 @@ hybridModel <- function(y, models = "aefnst",
if(parallel){
warning("The 'parallel' argument is currently unimplemented. Ignoring for now.")
}
# if(weights == "cv.errors"){
# warning("Cross validated error weights are currently unimplemented. Ignoring for now.")
# weights <- "equal"
# }

if(weights == "cv.errors" && errorMethod == "MASE"){
warning("cv errors currently do not support MASE. Reverting to RMSE.")
errorMethod <- "RMSE"
Expand Down Expand Up @@ -407,10 +401,33 @@ hybridModel <- function(y, models = "aefnst",
tsp(fits) <- tsp(resid) <- tsp(y)
}

# Save which models used xreg
xregs <- list()
if("a" %in% expandedModels){
xregs$auto.arima <- FALSE
if("xreg" %in% names(a.args) && !is.null(a.args$xreg)){
xregs$auto.arima <- TRUE
}
}
if("n" %in% expandedModels){
xregs$nnetar <- FALSE
if("xreg" %in% names(n.args) && !is.null(n.args$xreg)){
xregs$nnetar <- TRUE
}
}
if("s" %in% expandedModels){
xregs$stlm <- FALSE
methodArima <- "method" %in% names(s.args) && s.args$method == "arima"
if("xreg" %in% names(s.args) && !is.null(s.args$xreg) && methodArima){
xregs$stlm <- TRUE
}
}

# Prepare the hybridModel object
class(modelResults) <- "hybridModel"
modelResults$frequency <- frequency(y)
modelResults$x <- y
modelResults$xreg <- xregs
modelResults$models <- includedModels
modelResults$fitted <- fits
modelResults$residuals <- resid
Expand Down
10 changes: 5 additions & 5 deletions pkg/tests/testthat/test-forecast.hybridModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ if(require(forecast) & require(testthat)){
expect_error(forecast(object = hModel, h = "a"))
# h should be an integer
expect_error(forecast(object = hModel, h = 3.2))
#~ expect_error(forecast(object = hModel, h = 5,
#~ xreg = matrix(1:5, nrow = 5, ncol = 2)))
# matrix should be numeric
expect_error(forecast(object = hModel, h = 5,
xreg = matrix("a", nrow = 5, ncol = 2)))
Expand All @@ -32,14 +30,16 @@ if(require(forecast) & require(testthat)){
inputSeries <- ts(wineind[1:25], f = frequency(wineind))
mm <- matrix(runif(length(inputSeries)), nrow = length(inputSeries))
# stlm only works with xreg when method = "arima" is passed in s.args
expect_error(aa <- hybridModel(inputSeries, models = "afns", a.args = list(xreg = mm), s.args = list(xreg = mm)))
expect_error(aa <- hybridModel(inputSeries, models = "afns",
a.args = list(xreg = mm),
s.args = list(xreg = mm)))
aa <- hybridModel(inputSeries, models = "aefnst",
a.args = list(xreg = mm),
n.args = list (xreg = mm),
s.args = list(xreg = mm, method = "arima"))
# If xreg is used and no h is provided, overwrite h
newXreg <- matrix(rnorm(20), nrow = 20)
expect_error(tmp <- forecast(aa, xreg = newXreg, npaths = 50), NA)
expect_error(tmp <- forecast(aa, xreg = newXreg, npaths = 5), NA)
# If nrow(xreg) != h, issue a warning but set h <- nrow(xreg)
expect_warning(forecast(aa, h = 10, xreg = newXreg, PI = FALSE))
newXreg <- matrix(rnorm(24), nrow = 24)
Expand All @@ -49,7 +49,7 @@ if(require(forecast) & require(testthat)){
mod <- hybridModel(inputSeries, models = "af", a.args = list(xreg = mm))
expect_error(forecast(mod, xreg = newXreg), NA)
mod <- hybridModel(inputSeries, models = "nf", n.args = list(xreg = mm))
expect_error(forecast(mod, xreg = newXreg), NA)
expect_error(forecast(mod, xreg = newXreg, PI = FALSE), NA)
mod <- hybridModel(inputSeries, models = "sf", s.args = list(xreg = mm, method = "arima"))
expect_error(forecast(mod, xreg = newXreg), NA)

Expand Down

0 comments on commit 5639179

Please sign in to comment.