Skip to content

Commit

Permalink
Use the constant model when needed.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabrice-rossi committed Jul 7, 2022
1 parent 5a4b3a3 commit 93ea134
Showing 1 changed file with 44 additions and 36 deletions.
80 changes: 44 additions & 36 deletions R/glm_tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,49 +53,57 @@ fit_glm <- function(target, mm, nb_vals) {
assertthat::assert_that(nrow(mm) > 0)
engine <- options()[["mixvlmc.predictive"]]
assertthat::assert_that(engine %in% c("glm", "multinom"))
if (engine == "glm") {
if (nb_vals == 2) {
if (ncol(mm) > 0) {
suppressWarnings(result <-
stats::glm(target ~ .,
data = mm, family = stats::binomial(),
method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE,
model = FALSE
))
target_dist <- table(target)
target_dist <- target_dist[target_dist > 0]
if (length(target_dist) == 1) {
## degenerate case
constant_model(target, mm, nb_vals)
} else {
if (engine == "glm") {
if (nb_vals == 2) {
if (ncol(mm) > 0) {
suppressWarnings(result <-
stats::glm(target ~ .,
data = mm, family = stats::binomial(),
method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE,
model = FALSE
))
} else {
suppressWarnings(result <-
stats::glm(target ~ 1,
family = stats::binomial(),
method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE,
model = FALSE
))
}
} else {
suppressWarnings(result <-
stats::glm(target ~ 1,
family = stats::binomial(),
method = spaMM::spaMM_glm.fit, x = FALSE, y = FALSE,
model = FALSE
))
if (ncol(mm) > 0) {
suppressWarnings(result <-
VGAM::vglm(target ~ .,
data = mm, family = VGAM::multinomial(),
x.arg = FALSE, y.arg = FALSE, model = FALSE
))
} else {
suppressWarnings(result <-
VGAM::vglm(target ~ 1,
data = mm, family = VGAM::multinomial(),
x.arg = FALSE, y.arg = FALSE, model = FALSE
))
}
}
} else {
result
} else if (engine == "multinom") {
if (ncol(mm) > 0) {
suppressWarnings(result <-
VGAM::vglm(target ~ .,
data = mm, family = VGAM::multinomial(),
x.arg = FALSE, y.arg = FALSE, model = FALSE
))
result <- nnet::multinom(target ~ ., data = mm, trace = FALSE)
} else {
suppressWarnings(result <-
VGAM::vglm(target ~ 1,
data = mm, family = VGAM::multinomial(),
x.arg = FALSE, y.arg = FALSE, model = FALSE
))
result <- nnet::multinom(target ~ 1, trace = FALSE)
}
}
result
} else if (engine == "multinom") {
if (ncol(mm) > 0) {
result <- nnet::multinom(target ~ ., data = mm, trace = FALSE)
result$rank <- length(result$wts)
result
} else {
result <- nnet::multinom(target ~ 1, trace = FALSE)
## should not happen
NULL
}
result$rank <- length(result$wts)
result
} else {
NULL
}
}

Expand Down

0 comments on commit 93ea134

Please sign in to comment.