Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements evaluate #73

Merged
merged 28 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
51adecf
Factor out the `ctx` preparation for validation.
dfalbel Aug 26, 2021
e57aea2
Factor out the `valid_loop` as a function of `ctx` and `step`.
dfalbel Aug 26, 2021
9c6743a
Refactor `step` acquisition.
dfalbel Aug 26, 2021
1fd096c
Move cleaning context code to the context class.
dfalbel Aug 26, 2021
663ce6a
Lock the `ctx` objects so users can't add elements to it.
dfalbel Aug 26, 2021
d67668e
Start adding checks for `ctx` attributes. This will allow for more ma…
dfalbel Aug 26, 2021
e600c82
make a single source of truth for `input`, `target` and `batch`.
dfalbel Aug 26, 2021
e5c09ae
Check when assigning epochs to context.
dfalbel Aug 27, 2021
8f086c6
move even more checks to the context object
dfalbel Aug 27, 2021
0e51a2d
Get `metrics` back to the root of callbacks.
dfalbel Aug 31, 2021
9620776
Refactor context initialization. Now fully delegating to `initialize`…
dfalbel Aug 31, 2021
b7a8bc2
Simplified the fitted model output.
dfalbel Sep 1, 2021
7f283a4
Run `devtools::document()`.
dfalbel Sep 1, 2021
f4a83b9
Improve context cleanup.
dfalbel Sep 2, 2021
cd3c3c9
Remove objects from the enclosing environment that leaks.
dfalbel Sep 8, 2021
29deaab
add evaluate tests and print method
dfalbel Sep 14, 2021
6dc8050
tweak metric name
dfalbel Sep 14, 2021
838be0b
fix error message
dfalbel Sep 14, 2021
375394a
update plot
dfalbel Sep 14, 2021
91d753c
Document the evaluate function.
dfalbel Sep 14, 2021
83d9634
Fix missing argument in get_metrics
dfalbel Sep 14, 2021
8e0d903
sprinf -> sprintf
dfalbel Sep 14, 2021
b697acb
refer to records directly and add test
dfalbel Sep 14, 2021
7818377
improve documentation for evaluate.
dfalbel Sep 14, 2021
8212e8a
add exit handlers to avoid large leaks due to #74
dfalbel Sep 14, 2021
a7d8ab2
add a test for the print method
dfalbel Sep 14, 2021
6cddd2d
Include NEWS bullets.
dfalbel Sep 14, 2021
68ee9d8
Fix typos in docs
dfalbel Sep 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@ S3method(as_dataset,numeric)
S3method(as_dataset,torch_tensor)
S3method(as_iterator,device_dataloader)
S3method(fit,luz_module_generator)
S3method(get_metrics,luz_context)
S3method(get_metrics,luz_module_evaluation)
S3method(get_metrics,luz_module_fitted)
S3method(plot,lr_records)
S3method(plot,luz_module_fitted)
S3method(predict,luz_module_fitted)
S3method(print,lr_records)
S3method(print,luz_module_evaluation)
S3method(print,luz_module_fitted)
S3method(print,luz_module_generator)
export("%>%")
export(accelerator)
export(as_dataloader)
export(evaluate)
export(fit)
export(get_metrics)
export(lr_finder)
export(luz_callback)
export(luz_callback_csv_logger)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
* We now handle different kinds of data arguments passed to `fit` using the `as_dataloader()` method (#66).
* `valid_data` can now be scalar value indicating the proportion of `data` that will be used for fitting. This only works if `data` is a torch dataset or a list. (#69)
* You can now supply `dataloader_options` to `fit` to pass additional information to `as_dataloader()`. (#71)
* Refactored the `ctx` object to make it safer and avoid returing it in the output. (#73)
* Implemented the `evaluate` function allowing users to get metrics from a model in a new datase. (#73)

# luz 0.1.0

Expand Down
2 changes: 1 addition & 1 deletion R/as_dataloader.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ as_dataset <- function(x, ...) {

#' @export
as_dataset.default <- function(x, ...) {
rlang::abort(sprinf("Can't convert object with class '%s' to a torch dataset.", class(x)[1]))
rlang::abort(sprintf("Can't convert object with class '%s' to a torch dataset.", class(x)[1]))
}

#' @export
Expand Down
4 changes: 2 additions & 2 deletions R/callbacks-profile.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ luz_callback_profile <- luz_callback(
)

get_total_time <- function(x) {
unlist(x$ctx$records$profile$fit)
unlist(x$records$profile$fit)
}

get_average_time <- function(x, what) {
mean(unlist(x$ctx$records$profile[[what]]))
mean(unlist(x$records$profile[[what]]))
}
14 changes: 12 additions & 2 deletions R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,17 @@ default_callbacks <- function() {

default_predict_callbacks <- function() {
list(
luz_callback_progress()
luz_callback_progress(),
luz_callback_interrupt()
)
}

default_evaluate_callbacks <- function() {
list(
luz_callback_profile(),
luz_callback_metrics(),
luz_callback_progress(),
luz_callback_interrupt()
)
}

Expand Down Expand Up @@ -535,7 +545,7 @@ luz_callback_lr_scheduler <- luz_callback(
if (is.null(self$opt_name) && (length(ctx$optimizers) == 1))
self$opt_name <- names(ctx$optimizers)
else
rlang::abort("An optimizer name was not supported and your model has multiple optimizers")
rlang::abort("An optimizer name was not supplied and your model has multiple optimizers")

if (!self$opt_name %in% names(ctx$optimizers))
rlang::abort(glue::glue("opt_name '{self$opt_name}' not found in ctx$optimizers."))
Expand Down