diff --git a/R/glm_tools.R b/R/glm_tools.R index f7b21314..f1b1fac5 100644 --- a/R/glm_tools.R +++ b/R/glm_tools.R @@ -53,49 +53,57 @@ fit_glm <- function(target, mm, nb_vals) { assertthat::assert_that(nrow(mm) > 0) engine <- options()[["mixvlmc.predictive"]] assertthat::assert_that(engine %in% c("glm", "multinom")) - if (engine == "glm") { - if (nb_vals == 2) { - if (ncol(mm) > 0) { - suppressWarnings(result <- - stats::glm(target ~ ., - data = mm, family = stats::binomial(), - method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE, - model = FALSE - )) + target_dist <- table(target) + target_dist <- target_dist[target_dist > 0] + if (length(target_dist) == 1) { + ## degenerate case + constant_model(target, mm, nb_vals) + } else { + if (engine == "glm") { + if (nb_vals == 2) { + if (ncol(mm) > 0) { + suppressWarnings(result <- + stats::glm(target ~ ., + data = mm, family = stats::binomial(), + method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE, + model = FALSE + )) + } else { + suppressWarnings(result <- + stats::glm(target ~ 1, + family = stats::binomial(), + method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE, + model = FALSE + )) + } } else { - suppressWarnings(result <- - stats::glm(target ~ 1, - family = stats::binomial(), - method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE, - model = FALSE - )) + if (ncol(mm) > 0) { + suppressWarnings(result <- + VGAM::vglm(target ~ ., + data = mm, family = VGAM::multinomial(), + x.arg = FALSE, y.arg = FALSE, model = FALSE + )) + } else { + suppressWarnings(result <- + VGAM::vglm(target ~ 1, + data = mm, family = VGAM::multinomial(), + x.arg = FALSE, y.arg = FALSE, model = FALSE + )) + } } - } else { + result + } else if (engine == "multinom") { if (ncol(mm) > 0) { - suppressWarnings(result <- - VGAM::vglm(target ~ ., - data = mm, family = VGAM::multinomial(), - x.arg = FALSE, y.arg = FALSE, model = FALSE - )) + result <- nnet::multinom(target ~ ., data = mm, trace = FALSE) } else { - suppressWarnings(result <- - VGAM::vglm(target ~ 1, - data = mm, family = VGAM::multinomial(), - x.arg = FALSE, y.arg = FALSE, model = FALSE - )) + result <- nnet::multinom(target ~ 1, trace = FALSE) } - } - result - } else if (engine == "multinom") { - if (ncol(mm) > 0) { - result <- nnet::multinom(target ~ ., data = mm, trace = FALSE) + result$rank <- length(result$wts) + result } else { - result <- nnet::multinom(target ~ 1, trace = FALSE) + ## should not happen + NULL } - result$rank <- length(result$wts) - result - } else { - NULL } }