Skip to content

Commit

Permalink
fixed loading old model & refresh bug
Browse files Browse the repository at this point in the history
  • Loading branch information
gufengzhou committed Nov 3, 2021
1 parent d12dfec commit 6e40335
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 15 deletions.
13 changes: 9 additions & 4 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,12 @@ robyn_run <- function(InputCollect,
message(">>> Collecting results...")

## collect hyperparameter results
names(model_output_collect) <- paste0("trial", 1:InputCollect$trials)
if (hyper_fixed) {
names(model_output_collect) <- "trial1"
} else {
names(model_output_collect) <- paste0("trial", 1:InputCollect$trials)
}

resultHypParam <- rbindlist(lapply(model_output_collect, function(x) x$resultCollect$resultHypParam[, trial := x$trial]))
resultHypParam[, iterations := (iterNG - 1) * InputCollect$cores + iterPar]
xDecompAgg <- rbindlist(lapply(model_output_collect, function(x) x$resultCollect$xDecompAgg[, trial := x$trial]))
Expand Down Expand Up @@ -1063,14 +1068,14 @@ robyn_mmm <- function(hyper_collect,
# assign("InputCollect", InputCollect, envir = .GlobalEnv) # adding this to enable InputCollect reading during parallel
# opts <- list(progress = function(n) setTxtProgressBar(pb, n))
sysTimeDopar <- system.time({
for (lng in 1:iterNG) {
for (lng in 1:iterNG) { # lng = 1
nevergrad_hp <- list()
nevergrad_hp_val <- list()
hypParamSamList <- list()
hypParamSamNG <- c()

if (hyper_fixed == FALSE) {
for (co in 1:iterPar) {
for (co in 1:iterPar) { # co = 1

## get hyperparameter sample with ask
nevergrad_hp[[co]] <- optimizer$ask()
Expand Down Expand Up @@ -1119,7 +1124,7 @@ robyn_mmm <- function(hyper_collect,

getDoParWorkers()
doparCollect <- suppressPackageStartupMessages(
foreach(i = 1:iterPar) %dorng% {
foreach(i = 1:iterPar) %dorng% { # i = 1
t1 <- Sys.time()

#####################################
Expand Down
24 changes: 13 additions & 11 deletions R/R/refresh.R
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ robyn_refresh <- function(robyn_object,
OutputCollectRF$selectID <- selectID
message(
"Selected model ID: ", selectID, " for refresh model nr.",
refreshCounter, " based on the smallest combined error of NRMSE & DECOMP.RSSD"
refreshCounter, " based on the smallest combined error of NRMSE & DECOMP.RSSD\n"
)

OutputCollectRF$resultHypParam[, bestModRF := solID == selectID]
Expand Down Expand Up @@ -334,37 +334,39 @@ robyn_refresh <- function(robyn_object,
OutputCollectRF$mediaVecCollect[
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
ds <= refreshEnd
][, refreshStatus := refreshCounter]
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
)
mediaVecReport <- mediaVecReport[order(type, ds, refreshStatus)]
xDecompVecReport <- rbind(
listOutputPrev$xDecompVecCollect[bestModRF == TRUE],
OutputCollectRF$xDecompVecCollect[
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
ds <= refreshEnd
][, refreshStatus := refreshCounter]
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
)
} else {
resultHypParamReport <- rbind(listReportPrev$resultHypParamReport, OutputCollectRF$resultHypParam[
bestModRF == TRUE
][, refreshStatus := refreshCounter])
xDecompAggReport <- rbind(listReportPrev$xDecompAggReport, OutputCollectRF$xDecompAgg[
bestModRF == TRUE
][, refreshStatus := refreshCounter])
resultHypParamReport <- rbind(
listReportPrev$resultHypParamReport,
OutputCollectRF$resultHypParam[bestModRF == TRUE][
, refreshStatus := refreshCounter])
xDecompAggReport <- rbind(
listReportPrev$xDecompAggReport,
OutputCollectRF$xDecompAgg[bestModRF == TRUE][
, refreshStatus := refreshCounter])
mediaVecReport <- rbind(
listReportPrev$mediaVecReport,
OutputCollectRF$mediaVecCollect[
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
ds <= refreshEnd
][, refreshStatus := refreshCounter]
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
)
mediaVecReport <- mediaVecReport[order(type, ds, refreshStatus)]
xDecompVecReport <- rbind(
listReportPrev$xDecompVecReport,
OutputCollectRF$xDecompVecCollect[
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
ds <= refreshEnd
][, refreshStatus := refreshCounter]
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
)
}

Expand Down
11 changes: 11 additions & 0 deletions demo/debug.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ lambda.n = 100
lambda_control = 1
lambda_fixed = NULL
refresh = FALSE
seed = 123L
# go into robyn_mmm() line by line

## debug robyn_run
Expand All @@ -24,6 +25,16 @@ csv_out = "pareto"
seed = 123
# go into robyn_run() line by line

## debug robyn_refresh
# robyn_object
dt_input = dt_input
dt_holidays = dt_holidays
refresh_steps = 14
refresh_mode = "auto" # "auto", "manual"
refresh_iters = 100
refresh_trials = 2
plot_pareto = TRUE

## debug robyn_allocator
# prep input para

Expand Down

0 comments on commit 6e40335

Please sign in to comment.