Skip to content

Commit

Permalink
fix: removed duplicated results in JSON files + refresh doc typo
Browse files Browse the repository at this point in the history
  • Loading branch information
laresbernardo committed Aug 24, 2022
1 parent 3bca15d commit 17b7206
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 24 deletions.
17 changes: 9 additions & 8 deletions R/R/inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ robyn_inputs <- function(dt_input = NULL,
check_novar(select(dt_input, -all_of(unused_vars)))

## Collect input
InputCollect <- output <- list(
InputCollect <- list(
dt_input = dt_input,
dt_holidays = dt_holidays,
dt_mod = NULL,
Expand Down Expand Up @@ -320,7 +320,7 @@ robyn_inputs <- function(dt_input = NULL,
if (!is.null(hyperparameters)) {
### Conditional output 1.2
## Running robyn_inputs() for the 1st time & 'hyperparameters' provided --> run robyn_engineering()
output <- robyn_engineering(InputCollect, ...)
InputCollect <- robyn_engineering(InputCollect, ...)
}
} else {
### Use case 2: adding 'hyperparameters' and/or 'calibration_input' using robyn_inputs()
Expand Down Expand Up @@ -351,17 +351,17 @@ robyn_inputs <- function(dt_input = NULL,
## Update & check hyperparameters
if (is.null(InputCollect$hyperparameters)) InputCollect$hyperparameters <- hyperparameters
check_hyperparameters(InputCollect$hyperparameters, InputCollect$adstock, InputCollect$all_media)
output <- robyn_engineering(InputCollect, ...)
InputCollect <- robyn_engineering(InputCollect, ...)
}
}

if (!is.null(json_file)) {
pending <- which(!names(json$InputCollect) %in% output)
output <- append(output, json$InputCollect[pending])
pending <- which(!names(json$InputCollect) %in% names(InputCollect))
InputCollect <- append(InputCollect, json$InputCollect[pending])
}

class(output) <- c("robyn_inputs", class(output))
return(output)
class(InputCollect) <- c("robyn_inputs", class(InputCollect))
return(InputCollect)
}

#' @param x \code{robyn_inputs()} output.
Expand Down Expand Up @@ -760,7 +760,8 @@ prophet_decomp <- function(dt_transform, dt_holidays,
use_weekday <- "weekday" %in% prophet_vars | "weekly.seasonality" %in% prophet_vars

dt_regressors <- bind_cols(recurrence, select(
dt_transform, all_of(c(context_vars, paid_media_spends)))) %>%
dt_transform, all_of(c(context_vars, paid_media_spends))
)) %>%
mutate(ds = as.Date(.data$ds))

prophet_params <- list(
Expand Down
4 changes: 3 additions & 1 deletion R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,9 @@ init_msgs_run <- function(InputCollect, refresh, lambda_control = NULL, quiet =
"refreshDepth" %in% names(InputCollect),
InputCollect$refreshDepth,
ifelse("refreshCounter" %in% names(InputCollect),
InputCollect$refreshCounter, 0))
InputCollect$refreshCounter, 0
)
)
refresh <- as.integer(depth) > 0
message(sprintf(
"%s model is built on rolling window of %s %s: %s to %s",
Expand Down
4 changes: 3 additions & 1 deletion R/R/outputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ robyn_outputs <- function(InputCollect, OutputModels,
"refreshDepth" %in% names(InputCollect),
InputCollect$refreshDepth,
ifelse("refreshCounter" %in% names(InputCollect),
InputCollect$refreshCounter, 0))
InputCollect$refreshCounter, 0
)
)
refresh <- as.integer(depth) > 0
folder_var <- ifelse(!refresh, "init", paste0("rf", depth))
plot_folder_sub <- paste("Robyn", format(Sys.time(), "%Y%m%d%H%M"), folder_var, sep = "_")
Expand Down
4 changes: 2 additions & 2 deletions R/R/refresh.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
#' data("dt_simulated_weekly")
#' data("dt_prophet_holidays")
#' # Set the (pre-trained and exported) Robyn model JSON file
#' json_file <- "~/Robyn_202208081444_init/RobynModel-2_55_4.json
#' json_file <- "~/Robyn_202208081444_init/RobynModel-2_55_4.json"
#'
#' # Run \code{robyn_refresh()} with 13 weeks cadence in auto mode
#' Robyn <- robyn_refresh(
Expand All @@ -81,7 +81,7 @@
#' )
#'
#' # Run \code{robyn_refresh()} with 4 weeks cadence in manual mode
#' json_file2 <- "~/Robyn_202208081444_init/Robyn_202208090847_rf/RobynModel-1_2_3.json
#' json_file2 <- "~/Robyn_202208081444_init/Robyn_202208090847_rf/RobynModel-1_2_3.json"
#' Robyn <- robyn_refresh(
#' json_file = json_file2,
#' dt_input = dt_simulated_weekly,
Expand Down
19 changes: 11 additions & 8 deletions R/R/response.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,15 @@ robyn_response <- function(InputCollect = NULL,
### Use previously exported model using json_file
if (!is.null(json_file)) {
if (is.null(InputCollect)) InputCollect <- robyn_inputs(json_file = json_file, ...)
if (is.null(OutputCollect)) OutputCollect <- robyn_run(
InputCollect = InputCollect,
json_file = json_file,
export = FALSE,
quiet = quiet,
...
)
if (is.null(OutputCollect)) {
OutputCollect <- robyn_run(
InputCollect = InputCollect,
json_file = json_file,
export = FALSE,
quiet = quiet,
...
)
}
if (is.null(dt_hyppar)) dt_hyppar <- OutputCollect$resultHypParam
if (is.null(dt_coef)) dt_coef <- OutputCollect$xDecompAgg
} else {
Expand Down Expand Up @@ -140,8 +142,9 @@ robyn_response <- function(InputCollect = NULL,
}
}

if ("selectID" %in% names(OutputCollect))
if ("selectID" %in% names(OutputCollect)) {
select_model <- OutputCollect$selectID
}

## Prep environment
if (TRUE) {
Expand Down
4 changes: 2 additions & 2 deletions R/man/robyn_refresh.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions demo/demo.R
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ RobynRefresh <- robyn_refresh(
dt_input = dt_simulated_weekly,
dt_holidays = dt_prophet_holidays,
refresh_steps = 13,
refresh_iters = 1000, # 1k is estimation. Use refresh_mode = "manual" to try out.
refresh_iters = 1000, # 1k is an estimation
refresh_trials = 1
)

Expand All @@ -447,7 +447,7 @@ RobynRefresh <- robyn_refresh(
dt_input = dt_simulated_weekly,
dt_holidays = dt_prophet_holidays,
refresh_steps = 7,
refresh_iters = 1000, # 1k is estimation. Use refresh_mode = "manual" to try out.
refresh_iters = 1000, # 1k is an estimation
refresh_trials = 1
)

Expand Down

0 comments on commit 17b7206

Please sign in to comment.