Skip to content

Commit

Permalink
fix: deprecated plot_folder_sub to avoid chain-issues
Browse files Browse the repository at this point in the history
  • Loading branch information
laresbernardo committed Oct 5, 2022
1 parent 8398f12 commit 81628de
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 40 deletions.
2 changes: 1 addition & 1 deletion R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ check_robyn_name <- function(robyn_object, quiet = FALSE) {
}
}

check_filedir <- function(plot_folder) {
check_dir <- function(plot_folder) {
file_end <- substr(plot_folder, nchar(plot_folder) - 3, nchar(plot_folder))
if (file_end == ".RDS") {
plot_folder <- dirname(plot_folder)
Expand Down
5 changes: 4 additions & 1 deletion R/R/json.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ robyn_recreate <- function(json_file, quiet = FALSE, ...) {
# Import the whole chain any refresh model to init
robyn_chain <- function(json_file) {
json_data <- robyn_read(json_file, quiet = TRUE)
ids <- c(json_data$InputCollect$refreshChain, json_data$ExportedModel$select_model)
plot_folder <- json_data$ExportedModel$plot_folder
temp <- stringr::str_split(plot_folder, "/")[[1]]
chain <- temp[startsWith(temp, "Robyn_")]
Expand All @@ -278,6 +279,8 @@ robyn_chain <- function(json_file) {
dirs <- sapply(chainData, function(x) x$ExportedModel$plot_folder)
json_files <- paste0(dirs, "RobynModel-", names(dirs), ".json")
attr(chainData, "json_files") <- json_files
attr(chainData, "chain") <- names(chainData)
attr(chainData, "chain") <- ids # names(chainData)
if (length(ids) != length(names(chainData)))
warning("Can't replicate chain-like results if you don't follow Robyn's chain structure")
return(invisible(chainData))
}
38 changes: 15 additions & 23 deletions R/R/outputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
#' selection. Lower \code{calibration_constraint} increases calibration accuracy.
#' @param plot_folder Character. Path for saving plots. Default
#' to \code{robyn_object} and saves plot in the same directory as \code{robyn_object}.
#' @param plot_folder_sub Character. Customize sub path to save plots. The total
#' path is created with \code{dir.create(file.path(plot_folder, plot_folder_sub))}.
#' For example, plot_folder_sub = "sub_dir".
#' @param plot_pareto Boolean. Set to \code{FALSE} to deactivate plotting
#' and saving model one-pagers. Used when testing models.
#' @param clusters Boolean. Apply \code{robyn_clusters()} to output models?
Expand All @@ -43,15 +40,15 @@
robyn_outputs <- function(InputCollect, OutputModels,
pareto_fronts = "auto",
calibration_constraint = 0.1,
plot_folder = NULL, plot_folder_sub = NULL,
plot_folder = NULL,
plot_pareto = TRUE,
csv_out = "pareto",
clusters = TRUE,
select_model = "clusters",
ui = FALSE, export = TRUE,
quiet = FALSE, ...) {
if (is.null(plot_folder)) plot_folder <- getwd()
plot_folder <- check_filedir(plot_folder)
plot_folder <- check_dir(plot_folder)

# Check calibration constrains
calibration_constraint <- check_calibconstr(
Expand Down Expand Up @@ -91,6 +88,19 @@ robyn_outputs <- function(InputCollect, OutputModels,
plotDataCollect = pareto_results$plotDataCollect
)

# Set folder to save outputs: legacy plot_folder_sub
if (TRUE) {
depth <- ifelse(
"refreshDepth" %in% names(InputCollect),
InputCollect$refreshDepth,
ifelse("refreshCounter" %in% names(InputCollect),
InputCollect$refreshCounter, 0
)
)
folder_var <- ifelse(!as.integer(depth) > 0, "init", paste0("rf", depth))
plot_folder_sub <- paste("Robyn", format(Sys.time(), "%Y%m%d%H%M"), folder_var, sep = "_")
}

# Final results object
OutputCollect <- list(
resultHypParam = filter(pareto_results$resultHypParam, .data$solID %in% allSolutions),
Expand All @@ -115,24 +125,6 @@ robyn_outputs <- function(InputCollect, OutputModels,
)
class(OutputCollect) <- c("robyn_outputs", class(OutputCollect))

# Set folder to save outputs
if (is.null(plot_folder_sub)) {
refresh <- attr(OutputModels, "refresh")
depth <- ifelse(
"refreshDepth" %in% names(InputCollect),
InputCollect$refreshDepth,
ifelse("refreshCounter" %in% names(InputCollect),
InputCollect$refreshCounter, 0
)
)
refresh <- as.integer(depth) > 0
folder_var <- ifelse(!refresh, "init", paste0("rf", depth))
plot_folder_sub <- paste("Robyn", format(Sys.time(), "%Y%m%d%H%M"), folder_var, sep = "_")
}

plotPath <- paste0(plot_folder, "/", plot_folder_sub, "/")
OutputCollect$plot_folder <- gsub("//", "/", plotPath)

if (export) {
if (!dir.exists(OutputCollect$plot_folder)) dir.create(OutputCollect$plot_folder, recursive = TRUE)
tryCatch(
Expand Down
1 change: 0 additions & 1 deletion R/R/plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,6 @@ refresh_plots_json <- function(OutputCollectRF, json_file, export = TRUE) {
group_by(.data$solID, .data$label, .data$variable) %>%
summarise_all(sum)

df <- replace(df, is.na(df), 0)
outputs[["pBarRF"]] <- pBarRF <- df %>%
ggplot(aes(y = .data$variable)) +
geom_col(aes(x = .data$decompPer)) +
Expand Down
7 changes: 3 additions & 4 deletions R/R/refresh.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ robyn_refresh <- function(json_file = NULL,
robyn_object = NULL,
dt_input = NULL,
dt_holidays = Robyn::dt_prophet_holidays,
plot_folder_sub = NULL,
refresh_steps = 4,
refresh_mode = "manual",
refresh_iters = 1000,
refresh_trials = 3,
plot_folder = NULL,
plot_pareto = TRUE,
version_prompt = FALSE,
export = TRUE,
Expand All @@ -119,12 +119,12 @@ robyn_refresh <- function(json_file = NULL,
if (!is.null(json_file)) {
Robyn <- list()
json <- robyn_read(json_file, step = 2, quiet = TRUE)
listInit <- robyn_recreate(
listInit <- suppressWarnings(robyn_recreate(
json_file = json_file,
dt_input = dt_input,
dt_holidays = dt_holidays,
quiet = FALSE, ...
)
))
listInit$InputCollect$refreshSourceID <- json$ExportedModel$select_model
chainData <- robyn_chain(json_file)
listInit$InputCollect$refreshChain <- attr(chainData, "chain")
Expand Down Expand Up @@ -277,7 +277,6 @@ robyn_refresh <- function(json_file = NULL,
OutputCollectRF <- robyn_run(
InputCollect = InputCollectRF,
plot_folder = objectPath,
plot_folder_sub = plot_folder_sub,
calibration_constraint = listOutputPrev[["calibration_constraint"]],
add_penalty_factor = listOutputPrev[["add_penalty_factor"]],
iterations = refresh_iters,
Expand Down
5 changes: 0 additions & 5 deletions R/man/robyn_outputs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 4 additions & 5 deletions R/man/robyn_refresh.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 81628de

Please sign in to comment.