From d6b86c1c63d45406110f8b05296cdbc8afc0ef2b Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Sun, 23 Mar 2025 23:01:40 +0100 Subject: [PATCH 1/2] Fix wrong order in decision boundary plot --- R/commonMachineLearningClassification.R | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R index 74496c54..f21c3167 100644 --- a/R/commonMachineLearningClassification.R +++ b/R/commonMachineLearningClassification.R @@ -564,13 +564,13 @@ act.fct = jaspResults[["actfct"]]$object, linear.output = FALSE ) - predictions <- as.factor(max.col(predict(fit, newdata = grid))) - levels(predictions) <- unique(dataset[, options[["target"]]]) + probabilities <- predict(fit, newdata = grid) + predictions <- levels(dataset[, options[["target"]]])[apply(probabilities, 1, which.max)] } else if (type == "rpart") { classificationResult <- jaspResults[["classificationResult"]]$object fit <- rpart::rpart(formula, data = dataset, method = "class", control = rpart::rpart.control(minsplit = options[["minObservationsForSplit"]], minbucket = options[["minObservationsInNode"]], maxdepth = options[["interactionDepth"]], cp = classificationResult[["penalty"]])) - predictions <- as.factor(max.col(predict(fit, newdata = grid))) - levels(predictions) <- unique(dataset[, options[["target"]]]) + probabilities <- predict(fit, newdata = grid) + predictions <- colnames(probabilities)[apply(probabilities, 1, which.max)] } else if (type == "svm") { classificationResult <- jaspResults[["classificationResult"]]$object fit <- e1071::svm(formula, @@ -580,8 +580,8 @@ predictions <- predict(fit, newdata = grid) } else if (type == "naivebayes") { fit <- e1071::naiveBayes(formula, data = dataset, laplace = options[["smoothingParameter"]]) - predictions <- as.factor(max.col(predict(fit, newdata = grid, type = "raw"))) - levels(predictions) <- unique(dataset[, options[["target"]]]) + probabilities <- predict(fit, newdata = grid, type = "raw") + predictions <- colnames(probabilities)[apply(probabilities, 1, which.max)] } else if (type == "logistic") { if (classificationResult[["family"]] == "binomial") { fit <- stats::glm(formula, data = dataset, family = stats::binomial(link = options[["link"]])) From bde825d9e7beb73a72677eafa894794d823eb683 Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Sun, 23 Mar 2025 23:13:11 +0100 Subject: [PATCH 2/2] Also fix logistic --- R/commonMachineLearningClassification.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R index f21c3167..54568626 100644 --- a/R/commonMachineLearningClassification.R +++ b/R/commonMachineLearningClassification.R @@ -585,8 +585,8 @@ } else if (type == "logistic") { if (classificationResult[["family"]] == "binomial") { fit <- stats::glm(formula, data = dataset, family = stats::binomial(link = options[["link"]])) - predictions <- as.factor(round(predict(fit, grid, type = "response"), 0)) - levels(predictions) <- unique(dataset[, options[["target"]]]) + probabilities <- predict(fit, grid, type = "response") + predictions <- levels(dataset[, options[["target"]]])[round(probabilities, 0) + 1] } else { fit <- VGAM::vglm(formula, data = dataset, family = VGAM::multinomial()) logodds <- predict(fit, newdata = grid)