Skip to content

Commit

Permalink
Merge pull request #73 from mlverse/feature/evaluate
Browse files Browse the repository at this point in the history
Implements `evaluate`
  • Loading branch information
dfalbel committed Sep 15, 2021
2 parents bd8f737 + 68ee9d8 commit 4b8ba20
Show file tree
Hide file tree
Showing 28 changed files with 992 additions and 191 deletions.
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@ S3method(as_dataset,numeric)
S3method(as_dataset,torch_tensor)
S3method(as_iterator,device_dataloader)
S3method(fit,luz_module_generator)
S3method(get_metrics,luz_context)
S3method(get_metrics,luz_module_evaluation)
S3method(get_metrics,luz_module_fitted)
S3method(plot,lr_records)
S3method(plot,luz_module_fitted)
S3method(predict,luz_module_fitted)
S3method(print,lr_records)
S3method(print,luz_module_evaluation)
S3method(print,luz_module_fitted)
S3method(print,luz_module_generator)
export("%>%")
export(accelerator)
export(as_dataloader)
export(evaluate)
export(fit)
export(get_metrics)
export(lr_finder)
export(luz_callback)
export(luz_callback_csv_logger)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
* We now handle different kinds of data arguments passed to `fit` using the `as_dataloader()` method (#66).
* `valid_data` can now be scalar value indicating the proportion of `data` that will be used for fitting. This only works if `data` is a torch dataset or a list. (#69)
* You can now supply `dataloader_options` to `fit` to pass additional information to `as_dataloader()`. (#71)
* Refactored the `ctx` object to make it safer and avoid returing it in the output. (#73)
* Implemented the `evaluate` function allowing users to get metrics from a model in a new datase. (#73)

# luz 0.1.0

Expand Down
2 changes: 1 addition & 1 deletion R/as_dataloader.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ as_dataset <- function(x, ...) {

#' @export
as_dataset.default <- function(x, ...) {
rlang::abort(sprinf("Can't convert object with class '%s' to a torch dataset.", class(x)[1]))
rlang::abort(sprintf("Can't convert object with class '%s' to a torch dataset.", class(x)[1]))
}

#' @export
Expand Down
4 changes: 2 additions & 2 deletions R/callbacks-profile.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ luz_callback_profile <- luz_callback(
)

get_total_time <- function(x) {
unlist(x$ctx$records$profile$fit)
unlist(x$records$profile$fit)
}

get_average_time <- function(x, what) {
mean(unlist(x$ctx$records$profile[[what]]))
mean(unlist(x$records$profile[[what]]))
}
14 changes: 12 additions & 2 deletions R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,17 @@ default_callbacks <- function() {

default_predict_callbacks <- function() {
list(
luz_callback_progress()
luz_callback_progress(),
luz_callback_interrupt()
)
}

default_evaluate_callbacks <- function() {
list(
luz_callback_profile(),
luz_callback_metrics(),
luz_callback_progress(),
luz_callback_interrupt()
)
}

Expand Down Expand Up @@ -535,7 +545,7 @@ luz_callback_lr_scheduler <- luz_callback(
if (is.null(self$opt_name) && (length(ctx$optimizers) == 1))
self$opt_name <- names(ctx$optimizers)
else
rlang::abort("An optimizer name was not supported and your model has multiple optimizers")
rlang::abort("An optimizer name was not supplied and your model has multiple optimizers")

if (!self$opt_name %in% names(ctx$optimizers))
rlang::abort(glue::glue("opt_name '{self$opt_name}' not found in ctx$optimizers."))
Expand Down

0 comments on commit 4b8ba20

Please sign in to comment.