Skip to content

Commit

Permalink
Merge pull request #33 from vincentarelbundock/master
Browse files Browse the repository at this point in the history
bugfix prediction_glm.R newdata argument
  • Loading branch information
leeper committed Apr 6, 2019
2 parents 1645439 + de28a1e commit 96071dd
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
5 changes: 4 additions & 1 deletion DESCRIPTION
Expand Up @@ -10,7 +10,10 @@ Authors@R: c(person("Thomas J.", "Leeper",
email = "thosjleeper@gmail.com",
comment = c(ORCID = "0000-0003-4097-6326")),
person("Carl", "Ganz", role = "ctb",
email = "carlganz@ucla.edu")
email = "carlganz@ucla.edu"),
person("Vincent", "Arel-Bundock", role = "ctb",
email = "vincent.arel-bundock@umontreal.ca",
comment = c(ORCID = "0000-0003-2042-7063"))
)
URL: https://github.com/leeper/prediction
BugReports: https://github.com/leeper/prediction/issues
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
# prediction 0.3.13

* Fixed a bug in `prediction_glm` with the `data` argument (Issue #32).

# prediction 0.3.12

* Remove mnlogit dependency, as it has been removed from CRAN.
Expand Down
4 changes: 2 additions & 2 deletions R/prediction_glm.R
Expand Up @@ -51,7 +51,7 @@ function(model,
if (type == "link") {
means_for_prediction <- colMeans(model_mat)
} else if (type == "response") {
predictions_link <- predict(model, data = data, type = "link", se.fit = FALSE, ...)
predictions_link <- predict(model, newdata = data, type = "link", se.fit = FALSE, ...)
means_for_prediction <- colMeans(model$family$mu.eta(predictions_link) * model_mat)
}
J <- matrix(means_for_prediction, nrow = 1L)
Expand All @@ -64,7 +64,7 @@ function(model,
if (type == "link") {
means_for_prediction <- colMeans(model_mat)
} else if (type == "response") {
predictions_link <- predict(model, data = one, type = "link", se.fit = FALSE, ...)
predictions_link <- predict(model, newdata = one, type = "link", se.fit = FALSE, ...)
means_for_prediction <- colMeans(model$family$mu.eta(predictions_link) * model_mat)
}
means_for_prediction
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/tests-core.R
Expand Up @@ -19,6 +19,16 @@ test_that("Test prediction()", {
label = "prediction() matches predict() (GLM)")
})

test_that("Test prediction(data = )", {
m <- lm(mpg ~ cyl + wt, data = mtcars)
p1 <- prediction(m, data = data.frame(cyl = 4, wt = 3.9))
expect_true(inherits(p1, "data.frame"), label = "prediction(lm(~), data = data.frame()) works")

m <- glm(mpg ~ cyl + wt, data = mtcars)
p1 <- prediction(m, data = data.frame(cyl = 4, wt = 3.9))
expect_true(inherits(p1, "data.frame"), label = "prediction(glm(~), data = data.frame()) works")
})

test_that("Test prediction(at = )", {
m <- lm(mpg ~ cyl, data = mtcars)
p1 <- prediction(m, at = list(cyl = 4))
Expand Down

0 comments on commit 96071dd

Please sign in to comment.