Skip to content

Commit

Permalink
refactor: adapted to dynamic convergence rules
Browse files Browse the repository at this point in the history
  • Loading branch information
gufengzhou committed Feb 22, 2022
1 parent 50f559b commit 95091c9
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 34 deletions.
63 changes: 40 additions & 23 deletions R/R/convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,29 @@
#'
#' @param OutputModels List. Output from \code{robyn_run()}
#' @param n_cuts Integer. Default to 20 (5% cuts). Convergence is calculated
#' on using first and last quantile cuts. Criteria 1: last quantile's sd
#' < threshold_sd. Criteria 2: last quantile's median < first quantile's
#' median - 2 * sd. Both have to happen to consider convergence.
#' @param threshold_sd Numeric. Default to 0.025 that is empirically derived.
#' on using first and last quantile cuts. By default, criteria 1: last
#' quantile's sd < first 3 quantiles' mean sd. Criteria 2: last quantile's
#' median < first quantile's median - 3 * first 3 quantiles' mean sd. Both
#' have to be satisfied to consider convergence.
#' @param sd_qtref Integer. Reference quantile of the error convergence rule
#' for standard deviation. Defaults to 3. Error convergence rule for sd is
#' defined as by default: last quantile's sd < first 3 quantiles' mean sd.
#' @param med_lowb Integer. Lower bound distance of the error convergence rule
#' for median. Default to 3. Error convergence rule for median is defined as
#' by default: last quantile's median < first quantile's median - 3 * first 3
#' quantiles' mean sd.
#' @param ... Additional parameters
#' @examples
#' \dontrun{
#' OutputModels <- robyn_converge(
#' OutputModels = OutputModels,
#' n_cuts = 10,
#' threshold_sd = 0.025
#' n_cuts = 20,
#' sd_qtref = 3,
#' med_lowb = 3
#' )
#' }
#' @export
robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...) {
robyn_converge <- function(OutputModels, n_cuts = 20, sd_qtref = 3, med_lowb = 3, ...) {

# Gather all trials
get_lists <- as.logical(grepl("trial", names(OutputModels)) * sapply(OutputModels, is.list))
Expand Down Expand Up @@ -54,8 +62,8 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...)
))

# Calculate sd and median on each cut to alert user on:
# 1) last quantile's sd < threshold_sd
# 2) last quantile's median < first quantile's median - 2 * sd
# 1) last quantile's sd < mean sd of default first 3 qt
# 2) last quantile's median < median of first qt - default 3 * mean sd of defualt first 3 qt
errors <- dt_objfunc_cvg %>%
group_by(.data$error_type, .data$cuts) %>%
summarise(
Expand All @@ -66,29 +74,37 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...)
) %>%
group_by(.data$error_type) %>%
mutate(
med_var_P = abs(round(100 * (.data$median - lag(.data$median)) / .data$median, 2)),
flag_sd = .data$std > threshold_sd
med_var_P = abs(round(100 * (.data$median - lag(.data$median)) / .data$median, 2))
) %>%
group_by(.data$error_type) %>%
mutate(flag_med = dplyr::last(.data$median[1]) < dplyr::first(.data$median[2]) - 2 * dplyr::first(.data$std))
mutate(first_med = dplyr::first(.data$median),
first_med_avg = mean(.data$median[1:sd_qtref]),
last_med = dplyr::last(.data$median),
first_sd = dplyr::first(.data$std),
first_sd_avg = mean(.data$std[1:sd_qtref]),
last_sd = dplyr::last(.data$std)) %>%
mutate(med_thres = .data$first_med - med_lowb * .data$first_sd_avg,
flag_med = .data$median < .data$first_med - med_lowb * .data$first_sd_avg,
flag_sd = .data$std < .data$first_sd_avg)

conv_msg <- NULL
for (obj_fun in unique(errors$error_type)) {
temp.df <- filter(errors, .data$error_type == obj_fun) %>%
mutate(median = signif(median, 2))
last.qt <- tail(temp.df, 1)
temp <- glued(paste(
"{error_type} {did}converged: sd {sd} @qt.{quantile} {symb_sd} {sd_threh} &",
"med {qtn_median} @qt.{quantile} {symb_med} {med_threh} med@qt.1-2*sd"),
"{error_type} {did}converged: sd@qt.{quantile} {sd} {symb_sd} {sd_threh} &",
"med@qt.{quantile} {qtn_median} {symb_med} {med_threh} med@qt.1-{med_lowb}*sd"),
error_type = last.qt$error_type,
did = ifelse(last.qt$flag_sd | last.qt$flag_med, "NOT ", ""),
sd = signif(last.qt$std, 1),
symb_sd = ifelse(last.qt$flag_sd, ">", "<="),
sd_threh = threshold_sd,
quantile = round(100/n_cuts),
qtn_median = temp.df$median[n_cuts],
symb_med = ifelse(last.qt$flag_med, ">", "<="),
med_threh = signif(temp.df$median[1] - 2 * temp.df$std[1], 2)
did = ifelse(last.qt$flag_sd & last.qt$flag_med, "", "NOT "),
sd = signif(last.qt$last_sd, 2),
symb_sd = ifelse(last.qt$flag_sd, "<", ">="),
sd_threh = signif(last.qt$first_sd_avg, 2),
quantile = n_cuts,
qtn_median = signif(last.qt$last_med, 2),
symb_med = ifelse(last.qt$flag_med, "<", ">="),
med_threh = signif(last.qt$med_thres, 2),
med_lowb = med_lowb
)
conv_msg <- c(conv_msg, temp)
}
Expand Down Expand Up @@ -162,7 +178,8 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...)
errors = errors,
conv_msg = conv_msg
)
attr(cvg_out, "threshold_sd") <- threshold_sd
attr(cvg_out, "sd_qtref") <- sd_qtref
attr(cvg_out, "med_lowb") <- med_lowb

return(invisible(cvg_out))
}
1 change: 0 additions & 1 deletion R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ robyn_run <- function(InputCollect,
#' @export
print.robyn_models <- function(x, ...) {
is_fixed <- all(lapply(x$hyper_updated, length) == 1)
threshold_sd <- attr(x$convergence, "threshold_sd")
print(glued(
"
Total trials: {x$trials}
Expand Down
23 changes: 16 additions & 7 deletions R/man/robyn_converge.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions demo/demo.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ hyperparameters <- list(

,ooh_S_alphas = c(0.5, 3)
,ooh_S_gammas = c(0.3, 1)
,ooh_S_thetas = c(0.3) # (0.1, 0.4)
,ooh_S_thetas = c(0.1, 0.4)

,newsletter_alphas = c(0.5, 3)
,newsletter_gammas = c(0.3, 1)
Expand Down Expand Up @@ -276,9 +276,10 @@ OutputModels <- robyn_run(
)
print(OutputModels)

## Check MOO (multi-objective optimisation) convergence
## Check MOO (multi-objective optimisation) convergence plots
OutputModels$convergence$moo_distrb_plot
OutputModels$convergence$moo_cloud_plot
# check convergence rules ?robyn_converge

## Calculate Pareto optimality, cluster and export results and plots. See ?robyn_outputs
OutputCollect <- robyn_outputs(
Expand Down Expand Up @@ -306,7 +307,7 @@ print(OutputCollect)
# , plot_pareto = TRUE
# , plot_folder = robyn_object
# )
# convergence <- robyn_converge(OutputModels, n_cuts = 20, threshold_sd = 0.025)
# convergence <- robyn_converge(OutputModels)
# convergence$moo_distrb_plot
# convergence$moo_cloud_plot
# print(OutputCollect)
Expand Down

0 comments on commit 95091c9

Please sign in to comment.