Skip to content

Commit

Permalink
update floss
Browse files Browse the repository at this point in the history
  • Loading branch information
franzbischoff committed Nov 30, 2023
1 parent f694b52 commit c772fa0
Show file tree
Hide file tree
Showing 13 changed files with 23 additions and 23 deletions.
36 changes: 18 additions & 18 deletions analysis/regime_optimize.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ all_scores <- all_fitted %>%
dplyr::rename(fold = id, size = .sizes, record = .id, model = .config, pred = .pred) %>%
dplyr::distinct(rep, record, across(all_of(predictors_names)), .keep_all = TRUE) %>%
dplyr::mutate(truth = clean_truth(truth, size), pred = clean_pred(pred)) %>%
dplyr::mutate(score = score_regimes_limit(truth, pred, 0, 2500))
dplyr::mutate(score = score_regimes_weighted(truth, pred, 0))

holdout_scores <- outputs$final_evaluation %>%
dplyr::select(
Expand Down Expand Up @@ -611,14 +611,14 @@ if (file.exists(here("output", "importances.rds"))) {
cli::cli_alert_info("{round.POSIXt(Sys.time())} - Checking SHAP importances...")
importance_shap <- check_importance(best_fit, train_data, testing_data[, predictors_names], predictors_names,
type = "shap", nsim = 400, parallel = TRUE
type = "shap", nsim = 100, parallel = TRUE
)
cli::cli_alert_info("{round.POSIXt(Sys.time())} - Checking SHAP explanations...")
shap_fastshap_all_test <- shap_explain(best_fit, train_data[, predictors_names], testing_data[, predictors_names],
predictors_names,
nsim = 400, parallel = TRUE
nsim = 100, parallel = TRUE
)
preds_test <- predict(best_fit, testing_data[, predictors_names])
Expand All @@ -642,13 +642,13 @@ if (file.exists(here("output", "importances.rds"))) {
tree_data2 <- tree_data %>%
dplyr::mutate(
int_mp_w = mp_threshold * window_size,
# int_mp_rt = mp_threshold * regime_threshold
# int_mp_tc = mp_threshold * time_constraint
int_mp_rt = mp_threshold * regime_threshold,
int_w_tc = window_size * time_constraint,
int_tc_rt = time_constraint * regime_threshold,
.before = mean
)
predictor_names_int <- c(predictors_names, "int_mp_w")
predictor_names_int <- c(predictors_names, "int_mp_rt", "int_w_tc", "int_tc_rt")
trained_model2 <- NULL
# Caching ==========
Expand Down Expand Up @@ -676,7 +676,7 @@ if (file.exists(here("output", "importances2.rds"))) {
} else {
cli::cli_alert_info("{round.POSIXt(Sys.time())} - Checking SHAP importances...")
importance_shap2 <- check_importance(best_fit2, train_data2[, predictor_names_int], testing_data2[, predictor_names_int], predictor_names_int,
type = "shap", nsim = 400, parallel = TRUE
type = "shap", nsim = 100, parallel = TRUE
)
cli::cli_alert_info("{round.POSIXt(Sys.time())} - Checking interactions...")
interactions2 <- check_interactions(best_fit2, train_data2, predictors_names, parallel = TRUE)
Expand All @@ -690,7 +690,7 @@ if (file.exists(here("output", "importances2.rds"))) {
cli::cli_alert_info("{round.POSIXt(Sys.time())} - Checking SHAP explanations...")
shap_fastshap_all_test2 <- shap_explain(best_fit2, train_data2[, predictor_names_int], testing_data2[, predictor_names_int],
predictors_names,
nsim = 400, parallel = TRUE
nsim = 100, parallel = TRUE
)
preds_test2 <- predict(best_fit2, testing_data2[, predictor_names_int])
Expand Down Expand Up @@ -730,7 +730,7 @@ interactions_plot <- ggplot2::ggplot(interactions, ggplot2::aes(
y = ggplot2::element_blank(),
x = ggplot2::element_blank()
) +
# ggplot2::ylim(0, 1.2) +
ggplot2::ylim(0, 0.025) +
ggplot2::theme_bw() +
ggplot2::theme(legend.position = "none")
Expand All @@ -745,7 +745,7 @@ interactions2_plot <- ggplot2::ggplot(interactions2, ggplot2::aes(
y = "Interaction strength",
x = ggplot2::element_blank()
) +
# ggplot2::ylim(0, 1.2) +
ggplot2::ylim(0, 0.025) +
ggplot2::theme_bw() +
ggplot2::theme(legend.position = "none")
Expand Down Expand Up @@ -816,7 +816,7 @@ Fig. \@ref(fig:importance) then shows the variable importance using three method
```{r importance, fig.height = 7, fig.width= 15, out.width="100%", cache=FALSE}
#| fig.cap="Variable importances using three different methods. A) Feature Importance Ranking Measure
#| using ICE curves. B) Permutation method. C) SHAP (400 iterations). Line 1 refers to the original
#| using ICE curves. B) Permutation method. C) SHAP (100 iterations). Line 1 refers to the original
#| fit, and line 2 to the re-fit, taking into account the interactions between variables
#| (Fig. \\@ref(fig:interaction))."
Expand All @@ -827,7 +827,7 @@ importance_firm_plot <- importance_firm +
subtitle = "Individual Conditional Expectation",
y = ggplot2::element_blank()
) +
# ggplot2::ylim(0, 4.5) +
ggplot2::ylim(0, 0.04) +
ggplot2::theme_bw() +
ggplot2::theme(
legend.position = "none",
Expand All @@ -848,10 +848,10 @@ importance_perm_plot <- importance_perm +
importance_shap_plot <- importance_shap +
ggplot2::labs(
title = "SHAP (400 iterations)",
title = "SHAP (100 iterations)",
y = ggplot2::element_blank()
) +
# ggplot2::ylim(0, 2.5) +
ggplot2::ylim(0, 0.03) +
ggplot2::theme_bw() +
ggplot2::theme(
legend.position = "none",
Expand All @@ -863,7 +863,7 @@ importance_firm2_plot <- importance_firm2 +
ggplot2::labs(
y = "Importance"
) +
# ggplot2::ylim(0, 4.5) +
ggplot2::ylim(0, 0.04) +
ggplot2::theme_bw() +
ggplot2::theme(
legend.position = "none",
Expand All @@ -885,7 +885,7 @@ importance_shap2_plot <- importance_shap2 +
ggplot2::labs(
y = "Importance"
) +
# ggplot2::ylim(0, 2.5) +
ggplot2::ylim(0, 0.03) +
ggplot2::theme_bw() +
ggplot2::theme(
legend.position = "none",
Expand All @@ -894,7 +894,7 @@ importance_shap2_plot <- importance_shap2 +
all <- (importance_firm_plot / importance_firm2_plot + plot_layout(tag_level = "new")) |
(importance_perm_plot / importance_firm2_plot + plot_layout(tag_level = "new")) |
(importance_perm_plot / importance_perm2_plot + plot_layout(tag_level = "new")) |
(importance_shap_plot / importance_shap2_plot + plot_layout(tag_level = "new")) +
plot_layout(guides = "collect")
Expand Down
6 changes: 3 additions & 3 deletions analysis/regime_optimize_3.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ all_scores <- all_fitted %>%
dplyr::rename(fold = id, size = .sizes, record = .id, model = .config, pred = .pred) %>%
dplyr::distinct(rep, record, across(all_of(predictors_names)), .keep_all = TRUE) %>%
dplyr::mutate(truth = clean_truth(truth, size), pred = clean_pred(pred)) %>%
dplyr::mutate(score = score_regimes_limit(truth, pred, 0, 2500))
dplyr::mutate(score = score_regimes_weighted(truth, pred, 0))
```


Expand Down Expand Up @@ -540,7 +540,7 @@ if (file.exists(here("output", "importances_mvds.rds"))) {
cli::cli_alert_info("{round.POSIXt(Sys.time())} - Checking SHAP importances...")
importance_shap <- check_importance(best_fit, train_data, testing_data[, predictors_names], predictors_names,
type = "shap", nsim = 400, parallel = TRUE
type = "shap", nsim = 100, parallel = TRUE
)
importance_shap <- ggplot2::ggplot_build(importance_shap)$plot$data
Expand All @@ -559,7 +559,7 @@ if (file.exists(here("output", "importances_mvds.rds"))) {
shap_fastshap_all_test <- shap_explain(best_fit, train_data[, predictors_names], testing_data[, predictors_names],
predictors_names,
nsim = 400, parallel = TRUE
nsim = 100, parallel = TRUE
)
shap_html_test <- NA
# preds_test <- predict(best_fit, testing_data[, predictors_names])
Expand Down
4 changes: 2 additions & 2 deletions analysis/regime_optimize_4.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ all_scores <- all_fitted %>%
dplyr::rename(fold = id, size = .sizes, record = .id, model = .config, pred = .pred) %>%
dplyr::distinct(rep, record, across(all_of(predictors_names)), .keep_all = TRUE) %>%
dplyr::mutate(truth = clean_truth(truth, size), pred = clean_pred(pred)) %>%
dplyr::mutate(score = score_regimes_limit(truth, pred, 0, 2500))
dplyr::mutate(score = score_regimes_weighted(truth, pred, 0))
```


Expand Down Expand Up @@ -540,7 +540,7 @@ if (file.exists(here("output", "importances_vtds.rds"))) {
cli::cli_alert_info("{round.POSIXt(Sys.time())} - Checking SHAP importances...")
importance_shap <- check_importance(best_fit, train_data, testing_data[, predictors_names], predictors_names,
type = "shap", nsim = 400, parallel = TRUE
type = "shap", nsim = 100, parallel = TRUE
)
importance_shap <- ggplot2::ggplot_build(importance_shap)$plot$data
Expand Down
Binary file modified output/dbarts_fitted.rds
Binary file not shown.
Binary file modified output/dbarts_fitted2.rds
Binary file not shown.
Binary file modified output/dbarts_fitted_mvds.rds
Binary file not shown.
Binary file modified output/dbarts_fitted_vtds.rds
Binary file not shown.
Binary file modified output/importances.rds
Binary file not shown.
Binary file modified output/importances2.rds
Binary file not shown.
Binary file modified output/importances_mvds.rds
Binary file not shown.
Binary file modified output/importances_vtds.rds
Binary file not shown.
Binary file modified output/scores_stats_model_3.rds
Binary file not shown.
Binary file modified output/scores_stats_model_4.rds
Binary file not shown.

0 comments on commit c772fa0

Please sign in to comment.