From cdf58901e1b44c2b17f74de4ecb28f16b12166ff Mon Sep 17 00:00:00 2001 From: etiennebacher Date: Fri, 17 Mar 2023 12:50:22 +0100 Subject: [PATCH 1/3] add methods for fixest_multi --- NAMESPACE | 14 ++++++++++++++ R/find_formula.R | 5 +++++ R/find_parameters.R | 10 ++++++++++ R/find_predictors.R | 6 ++++++ R/find_response.R | 9 +++++++++ R/find_terms.R | 5 +++++ R/get_df.R | 5 +++++ R/get_df_residual.r | 5 +++++ R/get_parameters.R | 5 ++++- R/get_predicted_fixedeffects.R | 5 +++++ R/get_statistic.R | 5 +++++ R/get_varcov.R | 8 ++++++++ R/link_function.R | 5 +++++ R/link_inverse.R | 4 ++++ R/model_info.R | 5 +++++ R/n_obs.R | 4 ++++ 16 files changed, 99 insertions(+), 1 deletion(-) diff --git a/NAMESPACE b/NAMESPACE index 6c54f6a9aa..a76d706229 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -113,6 +113,7 @@ S3method(find_formula,feglm) S3method(find_formula,feis) S3method(find_formula,felm) S3method(find_formula,fixest) +S3method(find_formula,fixest_multi) S3method(find_formula,gam) S3method(find_formula,gamlss) S3method(find_formula,gamm) @@ -222,6 +223,7 @@ S3method(find_parameters,default) S3method(find_parameters,deltaMethod) S3method(find_parameters,emmGrid) S3method(find_parameters,emm_list) +S3method(find_parameters,fixest_multi) S3method(find_parameters,flexsurvreg) S3method(find_parameters,gam) S3method(find_parameters,gamlss) @@ -308,12 +310,14 @@ S3method(find_predictors,afex_aov) S3method(find_predictors,bfsl) S3method(find_predictors,default) S3method(find_predictors,fixest) +S3method(find_predictors,fixest_multi) S3method(find_predictors,logitr) S3method(find_predictors,selection) S3method(find_random,afex_aov) S3method(find_random,default) S3method(find_response,bfsl) S3method(find_response,default) +S3method(find_response,fixest_multi) S3method(find_response,joint) S3method(find_response,logitr) S3method(find_response,mediate) @@ -324,6 +328,7 @@ S3method(find_terms,afex_aov) S3method(find_terms,aovlist) S3method(find_terms,bfsl) S3method(find_terms,default) +S3method(find_terms,fixest_multi) S3method(find_terms,mipo) S3method(find_weights,brmsfit) S3method(find_weights,default) @@ -488,6 +493,7 @@ S3method(get_df,default) S3method(get_df,emmGrid) S3method(get_df,emm_list) S3method(get_df,fixest) +S3method(get_df,fixest_multi) S3method(get_df,lme) S3method(get_df,lmerMod) S3method(get_df,lmerModTest) @@ -593,6 +599,7 @@ S3method(get_parameters,deltaMethod) S3method(get_parameters,emmGrid) S3method(get_parameters,emm_list) S3method(get_parameters,epi.2by2) +S3method(get_parameters,fixest_multi) S3method(get_parameters,flexsurvreg) S3method(get_parameters,gam) S3method(get_parameters,gamlss) @@ -694,6 +701,7 @@ S3method(get_predicted,default) S3method(get_predicted,fa) S3method(get_predicted,faMain) S3method(get_predicted,fixest) +S3method(get_predicted,fixest_multi) S3method(get_predicted,gam) S3method(get_predicted,gamlss) S3method(get_predicted,gamm) @@ -791,6 +799,7 @@ S3method(get_statistic,epi.2by2) S3method(get_statistic,ergm) S3method(get_statistic,feis) S3method(get_statistic,fixest) +S3method(get_statistic,fixest_multi) S3method(get_statistic,flac) S3method(get_statistic,flexsurvreg) S3method(get_statistic,flic) @@ -913,6 +922,7 @@ S3method(get_varcov,crr) S3method(get_varcov,default) S3method(get_varcov,feis) S3method(get_varcov,fixest) +S3method(get_varcov,fixest_multi) S3method(get_varcov,flac) S3method(get_varcov,flexsurvreg) S3method(get_varcov,flic) @@ -1034,6 +1044,7 @@ S3method(link_function,feglm) S3method(link_function,feis) S3method(link_function,felm) S3method(link_function,fixest) +S3method(link_function,fixest_multi) S3method(link_function,flac) S3method(link_function,flexsurvreg) S3method(link_function,flic) @@ -1151,6 +1162,7 @@ S3method(link_inverse,feglm) S3method(link_inverse,feis) S3method(link_inverse,felm) S3method(link_inverse,fixest) +S3method(link_inverse,fixest_multi) S3method(link_inverse,flac) S3method(link_inverse,flexsurvreg) S3method(link_inverse,flic) @@ -1281,6 +1293,7 @@ S3method(model_info,feglm) S3method(model_info,feis) S3method(model_info,felm) S3method(model_info,fixest) +S3method(model_info,fixest_multi) S3method(model_info,flac) S3method(model_info,flexsurvreg) S3method(model_info,flic) @@ -1418,6 +1431,7 @@ S3method(n_obs,feglm) S3method(n_obs,feis) S3method(n_obs,felm) S3method(n_obs,fixest) +S3method(n_obs,fixest_multi) S3method(n_obs,flexsurvreg) S3method(n_obs,gam) S3method(n_obs,gamm) diff --git a/R/find_formula.R b/R/find_formula.R index 25f98e78e0..5f1b19fe0c 100644 --- a/R/find_formula.R +++ b/R/find_formula.R @@ -876,6 +876,11 @@ find_formula.fixest <- function(x, verbose = TRUE, ...) { .find_formula_return(f, verbose = verbose) } +#' @export +find_formula.fixest_multi <- function(x, verbose = TRUE, ...) { + lapply(x, find_formula.fixest, verbose, ...) +} + #' @export diff --git a/R/find_parameters.R b/R/find_parameters.R index 8777794a97..c44dcaa1ab 100644 --- a/R/find_parameters.R +++ b/R/find_parameters.R @@ -813,6 +813,16 @@ find_parameters.nls <- function(x, } } +#' @export +find_parameters.fixest_multi <- function(x, + component = c("all", "conditional", "nonlinear"), + flatten = FALSE, + ...) { + lapply(x, find_parameters.default, component, flatten, ...) +} + + + # helper ---------------------------- .filter_parameters <- function(l, effects, component = "all", flatten, recursive = TRUE) { diff --git a/R/find_predictors.R b/R/find_predictors.R index 21a9a4322c..25f5ce5091 100644 --- a/R/find_predictors.R +++ b/R/find_predictors.R @@ -195,6 +195,12 @@ find_predictors.fixest <- function(x, flatten = FALSE, ...) { } +#' @export +find_predictors.fixest_multi <- function(x, flatten = FALSE, ...) { + lapply(x, find_predictors.fixest, flatten, ...) +} + + #' @export find_predictors.bfsl <- function(x, flatten = FALSE, verbose = TRUE, ...) { l <- list(conditional = "x") diff --git a/R/find_response.R b/R/find_response.R index 847d31ce76..26b6bc6dfe 100644 --- a/R/find_response.R +++ b/R/find_response.R @@ -153,6 +153,15 @@ find_response.joint <- function(x, } +#' @export +find_response.fixest_multi <- function(x, + combine = TRUE, + component = c("conditional", "survival", "all"), + ...) { + lapply(x, find_response.default, combine, component, ...) +} + + # utils --------------------- diff --git a/R/find_terms.R b/R/find_terms.R index 0df1d40187..862c347614 100644 --- a/R/find_terms.R +++ b/R/find_terms.R @@ -135,6 +135,11 @@ find_terms.bfsl <- function(x, flatten = FALSE, verbose = TRUE, ...) { } } +#' @export +find_terms.fixest_multi <- function(x, flatten = FALSE, verbose = TRUE, ...) { + lapply(x, find_terms.default, flatten, verbose) +} + # unsupported ------------------ diff --git a/R/get_df.R b/R/get_df.R index 97b8fc7298..8a089d74e5 100644 --- a/R/get_df.R +++ b/R/get_df.R @@ -244,6 +244,11 @@ get_df.fixest <- function(x, type = "residual", ...) { fixest::degrees_freedom(x, type = type) } +#' @export +get_df.fixest_multi <- function(x, type = "residual", ...) { + lapply(x, get_df.fixest, type, ...) +} + # Mixed models - special treatment -------------- diff --git a/R/get_df_residual.r b/R/get_df_residual.r index 1d8eaceb57..c6ced9cbd0 100644 --- a/R/get_df_residual.r +++ b/R/get_df_residual.r @@ -91,6 +91,11 @@ fixest::degrees_freedom(x, type = "resid") } +#' @keywords internal +.degrees_of_freedom_residual.fixest_multi <- function(x, verbose = TRUE, ...) { + lapply(x, .degrees_of_freedom_residual.fixest, verbose, ...) +} + #' @keywords internal .degrees_of_freedom_residual.summary.lm <- function(x, verbose = TRUE, ...) { x$fstatistic[3] diff --git a/R/get_parameters.R b/R/get_parameters.R index 2a2aa9fb6e..8b098e1a95 100644 --- a/R/get_parameters.R +++ b/R/get_parameters.R @@ -777,7 +777,10 @@ get_parameters.pgmm <- function(x, component = c("conditional", "all"), ...) { text_remove_backticks(params) } - +#' @export +get_parameters.fixest_multi <- function(x, component = c("conditional", "all"), ...) { + lapply(x, get_parameters.default, component, ...) +} # utility functions --------------------------------- diff --git a/R/get_predicted_fixedeffects.R b/R/get_predicted_fixedeffects.R index 5b046db800..e150410407 100644 --- a/R/get_predicted_fixedeffects.R +++ b/R/get_predicted_fixedeffects.R @@ -36,3 +36,8 @@ get_predicted.fixest <- function(x, predict = "expectation", data = NULL, ...) { .get_predicted_out(predictions, args = args, ci_data = NULL) } + +#' @export +get_predicted.fixest_multi <- function(x, predict = "expectation", data = NULL, ...) { + lapply(x, get_predicted.fixest, predict, data, ...) +} \ No newline at end of file diff --git a/R/get_statistic.R b/R/get_statistic.R index 5cefb10db5..bc0425cdce 100644 --- a/R/get_statistic.R +++ b/R/get_statistic.R @@ -2030,6 +2030,11 @@ get_statistic.fixest <- function(x, ...) { out } +#' @export +get_statistic.fixest_multi <- function(x, ...) { + lapply(x, get_statistic.fixest, ...) +} + #' @export diff --git a/R/get_varcov.R b/R/get_varcov.R index 70592b490f..6b6ef35fed 100644 --- a/R/get_varcov.R +++ b/R/get_varcov.R @@ -136,6 +136,14 @@ get_varcov.fixest <- function(x, do.call("FUN", args) } +#' @export +get_varcov.fixest_multi <- function(x, + vcov = NULL, + vcov_args = NULL, + ...) { + lapply(x, get_varcov.fixest, vcov, vcov_args, ...) +} + # mlm --------------------------------------------- diff --git a/R/link_function.R b/R/link_function.R index be2081959c..9144e6f2da 100644 --- a/R/link_function.R +++ b/R/link_function.R @@ -463,6 +463,11 @@ link_function.fixest <- function(x, ...) { #' @export link_function.feglm <- link_function.fixest +#' @export +link_function.fixest_multi <- function(x, ...) { + lapply(x, link_function.fixest, ...) +} + #' @export link_function.glmx <- function(x, ...) { diff --git a/R/link_inverse.R b/R/link_inverse.R index 87bbca323b..bb143c9892 100644 --- a/R/link_inverse.R +++ b/R/link_inverse.R @@ -454,6 +454,10 @@ link_inverse.fixest <- function(x, ...) { #' @export link_inverse.feglm <- link_inverse.fixest +#' @export +link_inverse.fixest_multi <- function(x, ...) { + lapply(x, link_inverse.fixest_multi, ...) +} #' @export link_inverse.glmx <- function(x, ...) { diff --git a/R/model_info.R b/R/model_info.R index 3e83bd4c92..fae5cf92a4 100644 --- a/R/model_info.R +++ b/R/model_info.R @@ -542,6 +542,11 @@ model_info.fixest <- function(x, verbose = TRUE, ...) { #' @export model_info.feglm <- model_info.fixest +#' @export +model_info.fixest_multi <- function(x, verbose = TRUE, ...) { + lapply(x, model_info.fixest, verbose, ...) +} + # Survival-models ---------------------------------------- diff --git a/R/n_obs.R b/R/n_obs.R index c03b49e381..e0dbe1a47f 100644 --- a/R/n_obs.R +++ b/R/n_obs.R @@ -562,6 +562,10 @@ n_obs.fixest <- function(x, ...) { x$nobs } +#' @export +n_obs.fixest_multi <- function(x, ...) { + lapply(x, n_obs.fixest, ...) +} #' @export From 1a92aa831296606b58694663ee92de3c56609c3d Mon Sep 17 00:00:00 2001 From: etiennebacher Date: Fri, 17 Mar 2023 13:59:20 +0100 Subject: [PATCH 2/3] add missing methods, add tests --- NAMESPACE | 6 + R/find_offset.R | 10 + R/find_parameters.R | 2 +- R/find_statistic.R | 10 +- R/find_variables.R | 18 ++ R/link_inverse.R | 2 +- tests/testthat/test-fixest.R | 351 +++++++++++++++++++++++++++++++++++ 7 files changed, 396 insertions(+), 3 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index a76d706229..e18aa00b54 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -181,6 +181,8 @@ S3method(find_formula,wbm) S3method(find_formula,zcpglm) S3method(find_formula,zeroinfl) S3method(find_formula,zerotrunc) +S3method(find_offset,default) +S3method(find_offset,fixest_multi) S3method(find_parameters,BBmm) S3method(find_parameters,BBreg) S3method(find_parameters,BFBayesFactor) @@ -324,12 +326,16 @@ S3method(find_response,mediate) S3method(find_response,mjoint) S3method(find_response,model_fit) S3method(find_response,selection) +S3method(find_statistic,default) +S3method(find_statistic,fixest_multi) S3method(find_terms,afex_aov) S3method(find_terms,aovlist) S3method(find_terms,bfsl) S3method(find_terms,default) S3method(find_terms,fixest_multi) S3method(find_terms,mipo) +S3method(find_variables,default) +S3method(find_variables,fixest_multi) S3method(find_weights,brmsfit) S3method(find_weights,default) S3method(find_weights,gls) diff --git a/R/find_offset.R b/R/find_offset.R index 3057ea96d6..e043336cf2 100644 --- a/R/find_offset.R +++ b/R/find_offset.R @@ -27,6 +27,11 @@ #' } #' @export find_offset <- function(x) { + UseMethod("find_offset") +} + +#' @export +find_offset.default <- function(x) { terms <- .safe( as.character(attributes(stats::terms(find_formula(x)[[1]]))$variables), find_terms(x) @@ -57,3 +62,8 @@ find_offset <- function(x) { offset } + +#' @export +find_offset.fixest_multi <- function(x) { + lapply(x, find_offset.default) +} \ No newline at end of file diff --git a/R/find_parameters.R b/R/find_parameters.R index c44dcaa1ab..5c48e29a46 100644 --- a/R/find_parameters.R +++ b/R/find_parameters.R @@ -818,7 +818,7 @@ find_parameters.fixest_multi <- function(x, component = c("all", "conditional", "nonlinear"), flatten = FALSE, ...) { - lapply(x, find_parameters.default, component, flatten, ...) + lapply(x, find_parameters.default, component = component, flatten = flatten, ...) } diff --git a/R/find_statistic.R b/R/find_statistic.R index 2215ed8887..9ab8641df8 100644 --- a/R/find_statistic.R +++ b/R/find_statistic.R @@ -19,6 +19,11 @@ #' find_statistic(m) #' @export find_statistic <- function(x, ...) { + UseMethod("find_statistic") +} + +#' @export +find_statistic.default <- function(x, ...) { # model object check -------------------------------------------------------- # check if the object is a model object; if not, quit early @@ -339,7 +344,10 @@ find_statistic <- function(x, ...) { } } - +#' @export +find_statistic.fixest_multi <- function(x, ...) { + lapply(x, find_statistic.default, ...) +} diff --git a/R/find_variables.R b/R/find_variables.R index 1a8356e06f..1afeb2232e 100644 --- a/R/find_variables.R +++ b/R/find_variables.R @@ -59,6 +59,15 @@ find_variables <- function(x, component = "all", flatten = FALSE, verbose = TRUE) { + UseMethod("find_variables") +} + +#' @export +find_variables.default <- function(x, + effects = "all", + component = "all", + flatten = FALSE, + verbose = TRUE) { effects <- match.arg(effects, choices = c("all", "fixed", "random")) component <- match.arg(component, choices = c("all", "conditional", "zi", "zero_inflated", "dispersion", "instruments", "smooth_terms")) @@ -84,3 +93,12 @@ find_variables <- function(x, c(list(response = resp), pr) } } + +#' @export +find_variables.fixest_multi <- function(x, + effects = "all", + component = "all", + flatten = FALSE, + verbose = TRUE) { + lapply(x, find_variables.default, effects, component, flatten, verbose) +} \ No newline at end of file diff --git a/R/link_inverse.R b/R/link_inverse.R index bb143c9892..0dbfb3f34f 100644 --- a/R/link_inverse.R +++ b/R/link_inverse.R @@ -456,7 +456,7 @@ link_inverse.feglm <- link_inverse.fixest #' @export link_inverse.fixest_multi <- function(x, ...) { - lapply(x, link_inverse.fixest_multi, ...) + lapply(x, link_inverse.fixest, ...) } #' @export diff --git a/tests/testthat/test-fixest.R b/tests/testthat/test-fixest.R index e0d411ca0f..af6929e7bb 100644 --- a/tests/testthat/test-fixest.R +++ b/tests/testthat/test-fixest.R @@ -337,3 +337,354 @@ test_that("find_predictors with i(f1, i.f2) interaction", { ignore_attr = TRUE ) }) + + + +# fixest_multi ------------------------------- + + +m1 <- femlm(c(dist_km, Euros) ~ log(dist_km) | Origin + Destination + Product, data = trade) +m2 <- femlm(c(log1p(dist_km), log1p(Euros)) ~ log(dist_km) | Origin + Destination + Product, data = trade, family = "gaussian") +m3 <- feglm(c(dist_km, Euros) ~ log(dist_km) | Origin + Destination + Product, data = trade, family = "poisson") +m4 <- feols( + c(Sepal.Width, Petal.Length) ~ 1 | Species | Sepal.Length ~ Petal.Width, + data = iris +) + +test_that("fixest_multi: robust variance-covariance", { + mod <- feols(c(mpg, am) ~ hp + drat | cyl, data = mtcars) + # default is clustered + expect_equal( + sqrt(diag(vcov(mod[[1]]))), + sqrt(diag(get_varcov(mod, vcov = ~cyl)[[1]])), + tolerance = 1e-5, + ignore_attr = TRUE + ) + + # HC1 + expect_equal( + sqrt(diag(vcov(mod[[1]], vcov = "HC1"))), + sqrt(diag(get_varcov(mod, vcov = "HC1")[[1]])), + tolerance = 1e-5, + ignore_attr = TRUE + ) + + expect_true(all( + sqrt(diag(vcov(mod[[1]]))) != + sqrt(diag(get_varcov(mod, vcov = "HC1")[[1]])) + )) +}) + + +test_that("fixest_multi: offset", { + # need fix in fixest first: https://github.com/lrberge/fixest/issues/405 + + # tmp <- feols(c(mpg, am) ~ hp, offset = ~ log(qsec), data = mtcars) + # expect_identical(find_offset(tmp)[[1]], "qsec") + # tmp <- feols(c(mpg, am) ~ hp, offset = ~qsec, data = mtcars) + # expect_identical(find_offset(tmp)[[1]], "qsec") +}) + + +test_that("fixest_multi: model_info", { + expect_true(model_info(m1)[[2]]$is_count) + expect_true(model_info(m2)[[2]]$is_linear) + expect_true(model_info(m3)[[2]]$is_count) +}) + +test_that("fixest_multi: find_predictors", { + expect_identical( + find_predictors(m1)[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m2)[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m3)[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m4)[[1]], + list( + conditional = c("Sepal.Length"), cluster = "Species", + instruments = "Petal.Width", endogenous = "Sepal.Length" + ) + ) + expect_identical( + find_predictors(m1, component = "all")[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m2, component = "all")[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m3, component = "all")[[2]], + list(conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_predictors(m4, component = "all")[[2]], + list( + conditional = c("Sepal.Length"), + cluster = "Species", + instruments = "Petal.Width", + endogenous = "Sepal.Length" + ) + ) +}) + +test_that("fixest_multi: find_random", { + expect_null(find_random(m1)) + expect_null(find_random(m2)) + expect_null(find_random(m3)) +}) + +test_that("fixest_multi: get_varcov", { + expect_equal(vcov(m1[[1]]), get_varcov(m1)[[1]], tolerance = 1e-3) + expect_equal(vcov(m4[[1]]), get_varcov(m4)[[1]], tolerance = 1e-3) +}) + +test_that("fixest_multi: get_random", { + expect_warning(expect_null(get_random(m1))) +}) + +test_that("fixest_multi: find_response", { + expect_identical(find_response(m1)[[2]], "Euros") + expect_identical(find_response(m2)[[2]], "Euros") + expect_identical(find_response(m3)[[2]], "Euros") +}) + +test_that("fixest_multi: get_response", { + # expect_equal(get_response(m1)[[2]], trade$Euros, ignore_attr = TRUE) + # expect_equal(get_response(m2)[[2]], trade$Euros, ignore_attr = TRUE) + # expect_equal(get_response(m3)[[2]], trade$Euros, ignore_attr = TRUE) +}) + +test_that("fixest_multi: get_predictors", { + # expect_identical(colnames(get_predictors(m1)), c("dist_km", "Origin", "Destination", "Product")) + # expect_identical(colnames(get_predictors(m2)), c("dist_km", "Origin", "Destination", "Product")) + # expect_identical(colnames(get_predictors(m3)), c("dist_km", "Origin", "Destination", "Product")) +}) + +test_that("fixest_multi: link_inverse", { + expect_equal(link_inverse(m1[[1]])(0.2), exp(0.2), tolerance = 1e-4) + expect_equal(link_inverse(m2[[1]])(0.2), 0.2, tolerance = 1e-4) + expect_equal(link_inverse(m3[[1]])(0.2), exp(0.2), tolerance = 1e-4) +}) + +test_that("fixest_multi: link_function", { + expect_equal(link_function(m1[[1]])(0.2), log(0.2), tolerance = 1e-4) + expect_equal(link_function(m2[[1]])(0.2), 0.2, tolerance = 1e-4) + expect_equal(link_function(m3[[1]])(0.2), log(0.2), tolerance = 1e-4) +}) + +test_that("fixest_multi: get_data", { + # expect_identical(nrow(get_data(m1, verbose = FALSE)), 38325L) + # expect_identical(colnames(get_data(m1, verbose = FALSE)), c("Euros", "dist_km", "Origin", "Destination", "Product")) + # expect_identical(nrow(get_data(m2, verbose = FALSE)), 38325L) + # expect_identical(colnames(get_data(m2, verbose = FALSE)), c("Euros", "dist_km", "Origin", "Destination", "Product")) + # + # # old bug: m4 uses a complex formula and we need to extract all relevant + # # variables in order to compute predictions. + # nd <- get_data(m4, verbose = FALSE) + # tmp <- predict(m4, newdata = nd) + # expect_type(tmp, "double") + # expect_length(tmp, nrow(iris)) +}) + +if (skip_if_not_or_load_if_installed("parameters")) { + # test_that("fixest_multi: get_df", { + # expect_equal(get_df(m1, type = "residual"), 38290, ignore_attr = TRUE) + # expect_equal(get_df(m1, type = "normal"), Inf, ignore_attr = TRUE) + # ## TODO: check if statistic is z or t for this model + # expect_equal(get_df(m1, type = "wald"), 14, ignore_attr = TRUE) + # }) +} + +test_that("fixest_multi: find_formula", { + expect_length(find_formula(m1)[[1]], 2) + expect_equal( + find_formula(m1)[[2]], + list( + conditional = as.formula("Euros ~ log(dist_km)"), + cluster = as.formula("~Origin + Destination + Product") + ), + ignore_attr = TRUE + ) + expect_length(find_formula(m2)[[2]], 2) + expect_equal( + find_formula(m2)[[2]], + list( + conditional = as.formula("log1p(Euros) ~ log(dist_km)"), + cluster = as.formula("~Origin + Destination + Product") + ), + ignore_attr = TRUE + ) +}) + +test_that("fixest_multi: find_terms", { + expect_identical( + find_terms(m1)[[2]], + list(response = "Euros", conditional = "log(dist_km)", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_terms(m1, flatten = TRUE)[[2]], + c("Euros", "log(dist_km)", "Origin", "Destination", "Product") + ) + expect_identical( + find_terms(m2)[[2]], + list(response = "log1p(Euros)", conditional = "log(dist_km)", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_terms(m2, flatten = TRUE)[[2]], + c("log1p(Euros)", "log(dist_km)", "Origin", "Destination", "Product") + ) +}) + + +test_that("fixest_multi: find_variables", { + expect_identical( + find_variables(m1)[[2]], + list(response = "Euros", conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_variables(m1, flatten = TRUE)[[2]], + c("Euros", "dist_km", "Origin", "Destination", "Product") + ) + expect_identical( + find_variables(m2)[[2]], + list(response = "Euros", conditional = "dist_km", cluster = c("Origin", "Destination", "Product")) + ) + expect_identical( + find_variables(m1, flatten = TRUE)[[2]], + c("Euros", "dist_km", "Origin", "Destination", "Product") + ) +}) + + +test_that("fixest_multi: n_obs", { + expect_identical(n_obs(m1)[[1]], 38325L) + expect_identical(n_obs(m2)[[1]], 38325L) +}) + +test_that("fixest_multi: find_parameters", { + expect_identical( + find_parameters(m1)[[1]], + list(conditional = "log(dist_km)") + ) + expect_equal( + get_parameters(m1)[[2]], + data.frame( + Parameter = "log(dist_km)", + Estimate = -1.52774702640008, + row.names = NULL, + stringsAsFactors = FALSE + ), + tolerance = 1e-4 + ) + expect_identical( + find_parameters(m2)[[1]], + list(conditional = "log(dist_km)") + ) + expect_equal( + get_parameters(m2)[[2]], + data.frame( + Parameter = "log(dist_km)", + Estimate = -2.16843021944503, + row.names = NULL, + stringsAsFactors = FALSE + ), + tolerance = 1e-4 + ) +}) + +test_that("fixest_multi: is_multivariate", { + expect_false(is_multivariate(m1)[[1]]) +}) + +test_that("fixest_multi: find_statistic", { + expect_identical(find_statistic(m1)[[1]], "z-statistic") + expect_identical(find_statistic(m2)[[1]], "t-statistic") +}) + +test_that("fixest_multi: get_statistic", { + stat <- get_statistic(m1)[[2]] + expect_equal(stat$Statistic, -13.212695, tolerance = 1e-3) + stat <- get_statistic(m2)[[2]] + expect_equal(stat$Statistic, -14.065336, tolerance = 1e-3) +}) + +test_that("fixest_multi: get_predicted", { + # pred <- get_predicted(m1) + # expect_s3_class(pred, "get_predicted") + # expect_length(pred, nrow(trade)) + # a <- get_predicted(m1) + # b <- get_predicted(m1, type = "response", predict = NULL) + # expect_equal(a, b, tolerance = 1e-5) + # a <- get_predicted(m1, predict = "link") + # b <- get_predicted(m1, type = "link", predict = NULL) + # expect_equal(a, b, tolerance = 1e-5) + # # these used to raise warnings + # expect_warning(get_predicted(m1, ci = 0.4), NA) + # expect_warning(get_predicted(m1, predict = NULL, type = "link"), NA) +}) + +test_that("fixest_multi: get_data works when model data has name of reserved words", { + ## NOTE check back every now and then and see if tests still work + # skip("works interactively") + # rep <- data.frame(Y = runif(100) > 0.5, X = rnorm(100)) + # m <- feglm(Y ~ X, data = rep, family = binomial) + # out <- get_data(m) + # expect_s3_class(out, "data.frame") + # expect_equal( + # head(out), + # structure( + # list( + # Y = c(TRUE, TRUE, TRUE, TRUE, FALSE, FALSE), + # X = c( + # -1.37601434046896, -0.0340090992175856, 0.418083058388383, + # -0.51688491498936, -1.30634551903768, -0.858343109785566 + # ) + # ), + # is_subset = FALSE, row.names = c(NA, 6L), class = "data.frame" + # ), + # ignore_attr = TRUE, + # tolerance = 1e-3 + # ) +}) + + +test_that("fixest_multi: find_variables with interaction", { + mod <- suppressMessages(feols(c(mpg, drat) ~ 0 | carb | vs:cyl ~ am:cyl, data = mtcars)) + expect_equal( + find_variables(mod)[[1]], + list( + response = "mpg", conditional = "vs", cluster = "carb", + instruments = c("am", "cyl"), endogenous = c("vs", "cyl") + ), + ignore_attr = TRUE + ) + + # used to produce a warning + mod <- feols(c(mpg, drat) ~ 0 | carb | vs:cyl ~ am:cyl, data = mtcars) + expect_warning(find_variables(mod)[[1]], NA) +}) + + +test_that("fixest_multi: find_predictors with i(f1, i.f2) interaction", { + aq <- airquality + aq$week <- aq$Day %/% 7 + 1 + + mod <- feols(c(Ozone, Temp) ~ i(Month, i.week), aq, notes = FALSE) + expect_equal( + find_predictors(mod)[[1]], + list( + conditional = c("Month", "week") + ), + ignore_attr = TRUE + ) +}) + + From aab3c6f39973f58b6c39eb7e42d2232648414837 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 18 Mar 2023 11:16:51 +0100 Subject: [PATCH 3/3] update is_model et al --- R/is_model.R | 2 +- R/is_model_supported.R | 3 +- README.md | 157 +++++++++++++++++++++-------------------- 3 files changed, 82 insertions(+), 80 deletions(-) diff --git a/R/is_model.R b/R/is_model.R index 28f401f96e..266bebc356 100644 --- a/R/is_model.R +++ b/R/is_model.R @@ -81,7 +81,7 @@ is_regression_model <- function(x) { # f -------------------- "feglm", "feis", "felm", "fitdistr", "fixest", "flexmix", - "flexsurvreg", "flac", "flic", + "flexsurvreg", "flac", "flic", "fixest_multi", # g -------------------- "gam", "Gam", "GAMBoost", "gamlr", "gamlss", "gamm", "gamm4", diff --git a/R/is_model_supported.R b/R/is_model_supported.R index db3988f878..4bfc3f2771 100644 --- a/R/is_model_supported.R +++ b/R/is_model_supported.R @@ -65,7 +65,8 @@ supported_models <- function() { "eglm", "elm", "epi.2by2", "ergm", # f ---------------------------- - "feis", "felm", "feglm", "fitdistr", "fixest", "flexsurvreg", "flac", "flic", + "feis", "felm", "feglm", "fitdistr", "fixest", "flexsurvreg", "flac", + "flic", "fixest_multi", # g ---------------------------- "gam", "Gam", "gamlss", "gamm", "gamm4", "garch", "gbm", "gee", "geeglm", diff --git a/README.md b/README.md index d6efc8156c..b188ca1f0c 100644 --- a/README.md +++ b/README.md @@ -283,7 +283,7 @@ email or also file an issue. ## List of Supported Models by Class -Currently, 222 model classes are supported. +Currently, 223 model classes are supported. ``` r supported_models() @@ -321,83 +321,84 @@ supported_models() #> [63] "ergm" "feglm" #> [65] "feis" "felm" #> [67] "fitdistr" "fixest" -#> [69] "flac" "flexsurvreg" -#> [71] "flic" "gam" -#> [73] "Gam" "gamlss" -#> [75] "gamm" "gamm4" -#> [77] "garch" "gbm" -#> [79] "gee" "geeglm" -#> [81] "glht" "glimML" -#> [83] "glm" "Glm" -#> [85] "glmm" "glmmadmb" -#> [87] "glmmPQL" "glmmTMB" -#> [89] "glmrob" "glmRob" -#> [91] "glmx" "gls" -#> [93] "gmnl" "hglm" -#> [95] "HLfit" "htest" -#> [97] "hurdle" "iv_robust" -#> [99] "ivFixed" "ivprobit" -#> [101] "ivreg" "lavaan" -#> [103] "lm" "lm_robust" -#> [105] "lme" "lmerMod" -#> [107] "lmerModLmerTest" "lmodel2" -#> [109] "lmrob" "lmRob" -#> [111] "logistf" "logitmfx" -#> [113] "logitor" "logitr" -#> [115] "LORgee" "lqm" -#> [117] "lqmm" "lrm" -#> [119] "manova" "MANOVA" -#> [121] "marginaleffects" "marginaleffects.summary" -#> [123] "margins" "maxLik" -#> [125] "mblogit" "mclogit" -#> [127] "mcmc" "mcmc.list" -#> [129] "MCMCglmm" "mcp1" -#> [131] "mcp12" "mcp2" -#> [133] "med1way" "mediate" -#> [135] "merMod" "merModList" -#> [137] "meta_bma" "meta_fixed" -#> [139] "meta_random" "metaplus" -#> [141] "mhurdle" "mipo" -#> [143] "mira" "mixed" -#> [145] "MixMod" "mixor" -#> [147] "mjoint" "mle" -#> [149] "mle2" "mlm" -#> [151] "mlogit" "mmclogit" -#> [153] "mmlogit" "mmrm" -#> [155] "mmrm_fit" "mmrm_tmb" -#> [157] "model_fit" "multinom" -#> [159] "mvord" "negbinirr" -#> [161] "negbinmfx" "ols" -#> [163] "onesampb" "orm" -#> [165] "pgmm" "plm" -#> [167] "PMCMR" "poissonirr" -#> [169] "poissonmfx" "polr" -#> [171] "probitmfx" "psm" -#> [173] "Rchoice" "ridgelm" -#> [175] "riskRegression" "rjags" -#> [177] "rlm" "rlmerMod" -#> [179] "RM" "rma" -#> [181] "rma.uni" "robmixglm" -#> [183] "robtab" "rq" -#> [185] "rqs" "rqss" -#> [187] "rvar" "Sarlm" -#> [189] "scam" "selection" -#> [191] "sem" "SemiParBIV" -#> [193] "semLm" "semLme" -#> [195] "slm" "speedglm" -#> [197] "speedlm" "stanfit" -#> [199] "stanmvreg" "stanreg" -#> [201] "summary.lm" "survfit" -#> [203] "survreg" "svy_vglm" -#> [205] "svychisq" "svyglm" -#> [207] "svyolr" "t1way" -#> [209] "tobit" "trimcibt" -#> [211] "truncreg" "vgam" -#> [213] "vglm" "wbgee" -#> [215] "wblm" "wbm" -#> [217] "wmcpAKP" "yuen" -#> [219] "yuend" "zcpglm" -#> [221] "zeroinfl" "zerotrunc" +#> [69] "fixest_multi" "flac" +#> [71] "flexsurvreg" "flic" +#> [73] "gam" "Gam" +#> [75] "gamlss" "gamm" +#> [77] "gamm4" "garch" +#> [79] "gbm" "gee" +#> [81] "geeglm" "glht" +#> [83] "glimML" "glm" +#> [85] "Glm" "glmm" +#> [87] "glmmadmb" "glmmPQL" +#> [89] "glmmTMB" "glmrob" +#> [91] "glmRob" "glmx" +#> [93] "gls" "gmnl" +#> [95] "hglm" "HLfit" +#> [97] "htest" "hurdle" +#> [99] "iv_robust" "ivFixed" +#> [101] "ivprobit" "ivreg" +#> [103] "lavaan" "lm" +#> [105] "lm_robust" "lme" +#> [107] "lmerMod" "lmerModLmerTest" +#> [109] "lmodel2" "lmrob" +#> [111] "lmRob" "logistf" +#> [113] "logitmfx" "logitor" +#> [115] "logitr" "LORgee" +#> [117] "lqm" "lqmm" +#> [119] "lrm" "manova" +#> [121] "MANOVA" "marginaleffects" +#> [123] "marginaleffects.summary" "margins" +#> [125] "maxLik" "mblogit" +#> [127] "mclogit" "mcmc" +#> [129] "mcmc.list" "MCMCglmm" +#> [131] "mcp1" "mcp12" +#> [133] "mcp2" "med1way" +#> [135] "mediate" "merMod" +#> [137] "merModList" "meta_bma" +#> [139] "meta_fixed" "meta_random" +#> [141] "metaplus" "mhurdle" +#> [143] "mipo" "mira" +#> [145] "mixed" "MixMod" +#> [147] "mixor" "mjoint" +#> [149] "mle" "mle2" +#> [151] "mlm" "mlogit" +#> [153] "mmclogit" "mmlogit" +#> [155] "mmrm" "mmrm_fit" +#> [157] "mmrm_tmb" "model_fit" +#> [159] "multinom" "mvord" +#> [161] "negbinirr" "negbinmfx" +#> [163] "ols" "onesampb" +#> [165] "orm" "pgmm" +#> [167] "plm" "PMCMR" +#> [169] "poissonirr" "poissonmfx" +#> [171] "polr" "probitmfx" +#> [173] "psm" "Rchoice" +#> [175] "ridgelm" "riskRegression" +#> [177] "rjags" "rlm" +#> [179] "rlmerMod" "RM" +#> [181] "rma" "rma.uni" +#> [183] "robmixglm" "robtab" +#> [185] "rq" "rqs" +#> [187] "rqss" "rvar" +#> [189] "Sarlm" "scam" +#> [191] "selection" "sem" +#> [193] "SemiParBIV" "semLm" +#> [195] "semLme" "slm" +#> [197] "speedglm" "speedlm" +#> [199] "stanfit" "stanmvreg" +#> [201] "stanreg" "summary.lm" +#> [203] "survfit" "survreg" +#> [205] "svy_vglm" "svychisq" +#> [207] "svyglm" "svyolr" +#> [209] "t1way" "tobit" +#> [211] "trimcibt" "truncreg" +#> [213] "vgam" "vglm" +#> [215] "wbgee" "wblm" +#> [217] "wbm" "wmcpAKP" +#> [219] "yuen" "yuend" +#> [221] "zcpglm" "zeroinfl" +#> [223] "zerotrunc" ``` - **Didn’t find a model?** [File an