Skip to content

Commit d4de0bf

Browse files
authored
fix(tb-callback): log train loss every epoch (#405)
fix tb callback
1 parent bde52b8 commit d4de0bf

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

R/CallbackSetTB.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@ CallbackSetTB = R6Class("CallbackSetTB",
2828
dir.create(path, recursive = TRUE)
2929
}
3030
self$log_train_loss = assert_flag(log_train_loss)
31+
if (self$log_train_loss) {
32+
self$on_batch_end = function() {
33+
private$.log_train_loss()
34+
}
35+
}
3136
},
3237
#' @description
3338
#' Logs the training loss, training measures, and validation measures as TensorBoard events.
3439
on_epoch_end = function() {
35-
if (self$log_train_loss) {
36-
private$.log_train_loss()
37-
}
38-
3940
if (length(self$ctx$last_scores_train)) {
4041
walk(names(self$ctx$measures_train), private$.log_train_score)
4142
}

0 commit comments

Comments
 (0)