Skip to content

Commit

Permalink
mtscr_score() final
Browse files Browse the repository at this point in the history
  • Loading branch information
jakub-jedrusiak committed May 1, 2023
1 parent ba41013 commit 69676ea
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 10 deletions.
62 changes: 54 additions & 8 deletions R/mtscr_score.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,29 @@
#'
#' @inheritParams mtscr_prepare
#' @inheritParams mtscr_model
#' @param summarise_for Character, whether to get creativity scores for each person, each item or both.
#' Can be `"person"`, `"item"` or `"both"`.
#' @param append Logical, whether to add the scores columnd to the original data frame ('TRUE') or
#' return a new data frame with only the scores ('FALSE'). Additional columns depend on `summarise_for`.
#'
#' @return A data frame with the creativity scores in the columns `.all_max` and `.all_top2`.
#' @return A data frame with creativity scores. If `summarise_for` is `"person"`, the data frame
#' will have one row per person. If `summarise_for` is `"item"`, the data frame will have one
#' row per item. If `summarise_for` is `"both"`, the data frame will have one row per person-item.
#' If `append` is `TRUE`, the original data frame will be returned with the scores columns added.
#' By default (`append = FALSE`), only the scores columns and id and/or item columns
#' (depending on `summarise_for` value) will be returned.
#'
#' @export
#'
#' @examples
#' data("mtscr_creativity", package = "mtscr")
#' mtscr_score(mtscr_creativity, id, item, SemDis_MEAN)
mtscr_score <- function(df, id_column, item_column, score_column, model_type = c("all_max", "all_top2")) {
#' mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, summarise_for = "person") # one score for person
mtscr_score <- function(df, id_column, item_column, score_column, model_type = c("all_max", "all_top2"), summarise_for = "person", append = FALSE) {
id_column <- rlang::ensym(id_column)
item_column <- rlang::ensym(item_column)
score_column <- rlang::ensym(score_column)
df_original <- df

if (!all(model_type %in% c("all_max", "all_top2"))) {
cli::cli_abort(
Expand All @@ -27,18 +38,53 @@ mtscr_score <- function(df, id_column, item_column, score_column, model_type = c
# prepare
df <- mtscr_prepare(df, !!id_column, !!item_column, !!score_column, minimal = FALSE)
model <- mtscr_model(df, !!id_column, !!item_column, !!score_column, model_type = model_type, prepared = TRUE)
if (!is.list(model)) {
model <- list(model_type = model) # make it a named list for compatibility when predicting
}

# score
if ("all_max" %in% model_type) {
if (length(model_type) > 1) {
df$.all_max <- predict(model[["all_max"]], df)
}
if ("all_top2" %in% model_type) {
df$.all_top2 <- predict(model[["all_top2"]], df)
} else if (model_type == "all_max") {
df$.all_max <- predict(model, df)
} else if (model_type == "all_top2") {
df$.all_top2 <- predict(model, df)
}

df <- dplyr::select(df, -.z_score, -.ordering, -.ordering_0, -.ordering_top2_0, -.max_ind, -.top2_ind)

# summarise
if (summarise_for == "person") {
groups <- rlang::as_name(id_column)
} else if (summarise_for == "item") {
groups <- rlang::as_name(item_column)
} else if (summarise_for == "both") {
groups <- c(rlang::as_name(id_column), rlang::as_name(item_column))
} else {
cli::cli_abort(
c(
"{.arg summarise_for} must be a subset of {c('person', 'item', 'both')}.",
"x" = "{.var summarise_for} is not a valid value."
)
)
}

args <- list()
if ("all_max" %in% model_type) {
args[[".all_max"]] <- rlang::expr(max(.all_max))
}
if ("all_top2" %in% model_type) {
args[[".all_top2"]] <- rlang::expr(max(.all_top2))
}

df <- df |>
dplyr::summarise(
!!!args,
.by = dplyr::all_of(groups)
)

# append
if (append) {
df <- dplyr::left_join(df_original, df, by = groups)
}

return(df)
}
18 changes: 16 additions & 2 deletions man/mtscr_score.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 68 additions & 0 deletions tests/testthat/test-mtscr_score.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
data("mtscr_creativity", package = "mtscr")

# Test that `model_type` argument throws an error when invalid values are provided
test_that("model_type argument throws an error when invalid values are provided", {
# call function with model_type = "invalid"
expect_error(mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, model_type = "invalid"))

# call function with model_type = c("all_max", "invalid")
expect_error(mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, model_type = c("all_max", "invalid")))
})

# Test that `summarise_for` argument works as expected
# person
test_that("summarise_for argument works as expected for person", {
# call function with summarise_for = "person"
res_person <- mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, model_type = "all_max", summarise_for = "person")

# check that res_person has the expected number of rows and columns
expect_equal(ncol(res_person), 2)

# check that res_person has the expected column names
expect_equal(colnames(res_person), c("id", ".all_max"))
})

# item
test_that("summarise_for argument works as expected for item", {
# call function with summarise_for = "item"
res_item <- mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, model_type = "all_top2", summarise_for = "item")

# check that res_item has the expected number of rows and columns
expect_equal(ncol(res_item), 2)

# check that res_item has the expected column names
expect_equal(colnames(res_item), c("item", ".all_top2"))
})

# both
test_that("summarise_for argument works as expected for both", {
# call function with summarise_for = "both"
res_both <- mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, model_type = c("all_max", "all_top2"), summarise_for = "both")

# check that res_both has the expected number of rows and columns
expect_equal(ncol(res_both), 4)

# check that res_both has the expected column names
expect_equal(colnames(res_both), c("id", "item", ".all_max", ".all_top2"))
})

# invalid
test_that("summarise_for argument throws an error when invalid values are provided", {
# call function with summarise_for = "invalid"
expect_error(mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, model_type = "all_max", summarise_for = "invalid"))

# call function with summarise_for = c("person", "invalid")
expect_error(mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, model_type = "all_max", summarise_for = c("person", "invalid")))
})

# Test that `append` argument works as expected
test_that("append argument works as expected", {
# call function with append = TRUE
res_append <- mtscr_score(mtscr_creativity, id, item, SemDis_MEAN, append = TRUE)

# check that res_append has the expected number of rows and columns
expect_equal(ncol(res_append), ncol(mtscr_creativity) + 2)

# check that res_append has the expected column names
expect_named(res_append, c(names(mtscr_creativity), ".all_max", ".all_top2"), ignore.order = TRUE)
})

0 comments on commit 69676ea

Please sign in to comment.