Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Review scoringutils 2.0.0 #791

Merged
merged 11 commits into from
May 19, 2024
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ Authors@R: c(
family = "Bosse",
role = c("aut", "cre"),
email = "nikosbosse@gmail.com",
comment = c(ORCID = "https://orcid.org/0000-0002-7750-5280")),
comment = c(ORCID = "0000-0002-7750-5280")),
person(given = "Sam Abbott",
role = c("aut"),
email = "contact@samabbott.co.uk",
comment = c(ORCID = "0000-0001-8057-8037")),
person(given = "Hugo",
family = "Gruson",
role = c("aut"),
email = "hugo.gruson@lshtm.ac.uk",
comment = c(ORCID = "https://orcid.org/0000-0002-4094-1476")),
email = "hugo.gruson+R@normalesup.org",
comment = c(ORCID = "0000-0002-4094-1476")),
person(given = "Johannes Bracher",
role = c("ctb"),
email = "johannes.bracher@kit.edu",
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_function)
importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_matrix)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_subset)
Expand All @@ -104,6 +105,7 @@ importFrom(checkmate,check_set_equal)
importFrom(checkmate,check_vector)
importFrom(checkmate,test_factor)
importFrom(checkmate,test_list)
importFrom(checkmate,test_names)
importFrom(checkmate,test_numeric)
importFrom(checkmate,test_subset)
importFrom(cli,cli_abort)
Expand Down
9 changes: 3 additions & 6 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ check_columns_present <- function(data, columns) {
#' @keywords internal_input_check
test_columns_present <- function(data, columns) {
check <- check_columns_present(data, columns)
return(is.logical(check))
return(isTRUE(check))
}

#' Test whether column names are NOT present in a data.frame
Expand All @@ -175,11 +175,8 @@ test_columns_present <- function(data, columns) {
#' more columns are present, the function returns FALSE.
#' @inheritParams document_check_functions
#' @return Returns TRUE if none of the columns are present and FALSE otherwise
#' @importFrom checkmate test_names
#' @keywords internal_input_check
test_columns_not_present <- function(data, columns) {
if (any(columns %in% colnames(data))) {
return(FALSE)
} else {
return(TRUE)
}
test_names(colnames(data), disjunct.from = columns)
}
4 changes: 2 additions & 2 deletions R/check-inputs-scoring-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' N (number of columns) the number of samples per forecast.
#' If `observed` is just a single number, then predicted values can just be a
#' vector of size N.
#' @importFrom checkmate assert assert_numeric check_matrix
#' @importFrom checkmate assert assert_numeric check_matrix assert_matrix
#' @inherit document_assert_functions params return
#' @keywords internal_input_check
assert_input_sample <- function(observed, predicted) {
Expand All @@ -21,7 +21,7 @@ assert_input_sample <- function(observed, predicted) {
check_matrix(predicted, mode = "numeric", nrows = n_obs)
)
} else {
assert(check_matrix(predicted, mode = "numeric", nrows = n_obs))
assert_matrix(predicted, mode = "numeric", nrows = n_obs)
}
return(invisible(NULL))
}
Expand Down
2 changes: 1 addition & 1 deletion R/convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ transform_forecasts <- function(forecast,
#nolint start: keyword_quote_linter
cli_abort(
c(
"!" = "If a column 'scale' is present, entries with scale =='natural'
`!` = "If a column 'scale' is present, entries with scale =='natural'
are required for the transformation."
)
)
Expand Down
2 changes: 1 addition & 1 deletion R/correlations.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ get_correlations <- function(scores,
return(correlations[])
}

# helper function to obtain upper triangle of matrix
# helper function to obtain lower triangle of matrix
get_lower_tri <- function(cormat) {
cormat[lower.tri(cormat)] <- NA
return(cormat)
Expand Down
9 changes: 5 additions & 4 deletions R/default-scoring-rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ select_metrics <- function(metrics, select = NULL, exclude = NULL) {

if (is.null(select) && is.null(exclude)) {
return(metrics)
} else if (is.null(select)) {
}
if (is.null(select)) {
assert_subset(exclude, allowed)
select <- allowed[!allowed %in% exclude]
return(metrics[select])
} else {
assert_subset(select, allowed)
return(metrics[select])
}
assert_subset(select, allowed)
return(metrics[select])

}

#' Customises a metric function with additional arguments.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this function doesn't do anything extra compared to generic alternatives (e.g., argument checking), I would probably remove it and recommend the alternatives (e.g., purrr::partial())

Copy link
Contributor

@nikosbosse nikosbosse May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd have to add purrr as in import, but agree that using an existing alternative is nice in general.

What we're getting from the current implementation is

  • a slightly clearer function name (customise_metric() vs. partial())
  • the documentation in customise_metric() that explains what you're expected to do in the context of score()
  • potential future flexibility to add functionality like checking input functions etc.

I think I could be persuaded either way 🤷 But we should definitely make a decision before 2.0.0

Expand Down
6 changes: 2 additions & 4 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ print.forecast_binary <- function(x, ...) {
} else {
cli_text(
col_blue(
"Forecast type:"
)
)
cli_text(
"Forecast type: "
),
"{forecast_type}"
)
}
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ forecast_quantile <- example_quantile |>
#> unexpected.

print(forecast_quantile, 2)
#> Forecast type:
#> quantile
#> Forecast type: quantile
#> Forecast unit:
#> location, forecast_date, target_end_date, target_type, model, and horizon
#>
Expand Down
2 changes: 1 addition & 1 deletion man/scoringutils-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 20 additions & 40 deletions tests/testthat/_snaps/print.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
Code
print(dat)
Message
Forecast type:
binary
Forecast type: binary
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -42,8 +41,7 @@
Code
print(dat)
Message
Forecast type:
binary
Forecast type: binary
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -81,8 +79,7 @@
Code
print(dat)
Message
Forecast type:
binary
Forecast type: binary
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -120,8 +117,7 @@
Code
print(dat)
Message
Forecast type:
binary
Forecast type: binary
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -159,8 +155,7 @@
Code
print(dat)
Message
Forecast type:
quantile
Forecast type: quantile
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -199,8 +194,7 @@
Code
print(dat)
Message
Forecast type:
quantile
Forecast type: quantile
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -239,8 +233,7 @@
Code
print(dat)
Message
Forecast type:
quantile
Forecast type: quantile
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -279,8 +272,7 @@
Code
print(dat)
Message
Forecast type:
quantile
Forecast type: quantile
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -319,8 +311,7 @@
Code
print(dat)
Message
Forecast type:
point
Forecast type: point
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -359,8 +350,7 @@
Code
print(dat)
Message
Forecast type:
point
Forecast type: point
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -399,8 +389,7 @@
Code
print(dat)
Message
Forecast type:
point
Forecast type: point
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -439,8 +428,7 @@
Code
print(dat)
Message
Forecast type:
point
Forecast type: point
Forecast unit:
location, target_end_date, target_type, location_name, forecast_date, model,
and horizon
Expand Down Expand Up @@ -479,8 +467,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -518,8 +505,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -557,8 +543,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -596,8 +581,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -635,8 +619,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -674,8 +657,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -713,8 +695,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down Expand Up @@ -752,8 +733,7 @@
Code
print(dat)
Message
Forecast type:
sample
Forecast type: sample
Forecast unit:
location, location_name, target_end_date, target_type, forecast_date, model,
and horizon
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-summarise_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,12 @@ test_that("summarise_scores() across argument works as expected", {
scores, by = c("location", "target_type")
)
)

expect_warning(
summarise_scores(
scores, across = c("horizon", "model", "forecast_date", "target_end_date"),
by = c("model", "target_type")
),
"You specified `across` and `by` at the same time.`by` will be ignored"
)
})