At the end of a training run, `TaskTrainer` returns auxiliary information, such as losses and intermediate model states, in a `TaskTrainerHistory` object.
The history of the training loss is returned in the `loss` field, but the validation loss is not returned.
Validation is only performed once every `log_step` trials, so some care will be needed to keep the returned arrays comparable with those for the training loss; it might be appropriate to return arrays that are NaN everywhere except on the log steps. Memory should not be an issue: usually there won't be more than about 5 loss terms and (conservatively) fewer than 100,000 iterations, which comes to about 2 MB of float32.
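A minimal sketch of that idea, assuming a fixed number of training batches and a `log_step` interval (the variable names here are illustrative, not feedbax's actual internals):

```python
import jax.numpy as jnp

n_batches = 100_000  # conservative upper bound on training iterations
log_step = 500       # validation runs once every log_step batches

# Array aligned with the training loss history: NaN everywhere except
# on the log steps where validation actually ran.
validation_loss = jnp.full(n_batches, jnp.nan, dtype=jnp.float32)

for batch in range(0, n_batches, log_step):
    val_loss = 0.123  # placeholder for an actual validation pass
    validation_loss = validation_loss.at[batch].set(val_loss)

# Memory estimate from above: 5 loss terms * 100,000 iterations
# * 4 bytes per float32 ≈ 2 MB.
```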
Currently, the returned arrays of validation losses are zero everywhere except on batches where the model is validated. I avoided using `jnp.nan` so that I wouldn't need to alter the program's behaviour when troubleshooting NaNs arising elsewhere. Since zero is a valid loss value, it might be better to switch to a sentinel like -1.
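For illustration only (the sentinel value and variable names are hypothetical, not what feedbax currently does), a sentinel approach might look like:

```python
import jax.numpy as jnp

n_batches, log_step = 10_000, 100

# A sentinel in place of zero: -1 can't be confused with a genuine loss,
# and it avoids introducing NaNs that would get in the way when debugging
# NaNs arising elsewhere in the program.
SENTINEL = -1.0

validation_loss = jnp.full(n_batches, SENTINEL, dtype=jnp.float32)
validation_loss = validation_loss.at[::log_step].set(0.456)  # placeholder values

# Downstream, keep only the batches that were actually validated.
validated = validation_loss != SENTINEL
val_batches = jnp.flatnonzero(validated)
val_losses = validation_loss[validated]
```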
I've also added a switch to `feedbax.plotly.loss_history` to plot these validation losses. Unlike training losses, validation losses are typically sparse over time, so for them I've switched the plotting mode from `"markers+lines"` to just `"markers"`.
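A rough sketch of that plotting distinction in plain plotly, with fabricated loss curves (this is not the actual `feedbax.plotly.loss_history` signature):

```python
import numpy as np
import plotly.graph_objects as go

n_batches, log_step = 10_000, 100
batches = np.arange(n_batches)

# Fabricated curves for illustration only.
train_loss = np.exp(-batches / 3_000) + 0.05 * np.random.rand(n_batches)
val_batches = batches[::log_step]
val_loss = np.exp(-val_batches / 3_000) + 0.02

fig = go.Figure()
# Dense training losses keep the connected style.
fig.add_trace(go.Scatter(
    x=batches, y=train_loss, mode="markers+lines", name="training",
))
# Sparse validation losses are drawn as isolated markers.
fig.add_trace(go.Scatter(
    x=val_batches, y=val_loss, mode="markers", name="validation",
))
fig.update_xaxes(title="training batch")
fig.update_yaxes(type="log", title="loss")
fig.show()
```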