Skip to content

Commit

Permalink
Adds tests and fixes interval fit for dfs
Browse files Browse the repository at this point in the history
  • Loading branch information
edgararuiz committed Oct 6, 2018
1 parent 386f85e commit 7742d0f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
22 changes: 9 additions & 13 deletions R/predict-interval.R
Expand Up @@ -36,27 +36,23 @@ tidypredict_interval.glm <- function(model, interval = 0.95) {

#' @export
`tidypredict_interval.data.frame` <- function(model, interval = 0.95) {
model <- model %>%
mutate_if(is.factor, as.character) %>%
as.tibble()

model_type <- model %>%
filter(.data$labels == "model") %>%
pull(.data$vals)

model_type <- model[model$labels == "model", "vals"][[1]]

assigned <- 0

if (model_type == "lm") {
assigned <- 1
te_interval_lm(model)
ret <- te_interval_lm(model, interval = interval)
}

if (model_type == "glm") {
assigned <- 1
te_interval_glm(model)
ret <- te_interval_glm(model, interval = interval)
}

if (assigned == 0) {
stop("Model not recognized")
} else {
ret
}
}
8 changes: 8 additions & 0 deletions tests/testthat/test-tester.R
@@ -0,0 +1,8 @@
context("test-tester")

test_that("Tester returns warning", {
t <- tidypredict_test(
lm(mpg ~ wt, offset = am, data = mtcars),
threshold = 0)
expect_true(t$alert)
})
16 changes: 14 additions & 2 deletions tests/testthat/test_lm.R
@@ -1,7 +1,7 @@
context("lm")

df <- mtcars %>%
mutate(cyl = paste0("cyl", cyl))
df <- mtcars
df$cyl <- paste0("cyl", df$cyl)

has_alert <- function(model) {
# test1: check for any predictions are above the threshold
Expand All @@ -26,3 +26,15 @@ test_that("Predictions within threshold and parsed model results are equal", {
expect_false(has_alert(lm(mpg ~ wt + disp * cyl, data = df)))
expect_false(has_alert(lm(mpg ~ (wt + disp) * cyl, data = df)))
})

test_that("Intervals are within the threshold", {
pm <- parse_model(lm(mpg ~ am + wt, data = mtcars))
t <- tidypredict_interval(pm)
expected <- rlang::expr(2.0452296421327 * sqrt(-0.176776695296637 * -0.176776695296637 *
9.59723093870482 + (0.1462244129664 + (am) * (-0.359937016532678)) *
(0.1462244129664 + (am) * (-0.359937016532678)) * 9.59723093870482 +
(-0.958962404795433 + (am) * (0.345504476304964) + (wt) *
(0.25444128099978)) * (-0.958962404795433 + (am) * (0.345504476304964) +
(wt) * (0.25444128099978)) * 9.59723093870482 + 9.59723093870482))
expect_equal(t, expected)
})

0 comments on commit 7742d0f

Please sign in to comment.