Training API uses streams #60
Conversation
Thanks for putting this together! I left a comment to see if we can address that before we merge. BTW, make sure to run the formatter so it passes the CI!
x when is_integer(x) ->
  x

{model_state, avg_loss, total_batches} =
  for {{inp, tar}, i} <- dataset, reduce: {model_state, Nx.tensor(0.0), 0} do
I'm guessing this runs to the end of the stream? I think perhaps we'll want an option to terminate before the end of a stream. I'm trying to think of the best way to do that without loading `n` steps all into memory at once with something like `Enum.take`. Any thoughts?
You can do a `Stream.take(n)` or do a throw/try+catch.
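For reference, a minimal sketch of why `Stream.take/2` avoids the memory concern, assuming a generic lazy batch stream (the generator, `initial_state`, and `train_batch` below are illustrative placeholders, not code from this PR):

```elixir
# Hypothetical infinite stream of {input, target} batches; nothing is
# materialized yet.
batches =
  Stream.repeatedly(fn ->
    {Nx.iota({32, 784}, type: {:f, 32}), Nx.iota({32, 10}, type: {:f, 32})}
  end)

initial_state = %{}                                  # placeholder model state
train_batch = fn {_inp, _tar}, state -> state end    # placeholder training step

# Stream.take/2 stays lazy: batches are produced one at a time as the
# reducer pulls them, so only the current batch is in memory.
final_state =
  batches
  |> Stream.take(1000)
  |> Enum.reduce(initial_state, train_batch)

# Enum.take(batches, 1000) would instead realize all 1000 batches as a
# list before any training step runs.
```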
Can we go with `Enum.reduce_while/3`? We just need the accumulator when processing each batch, but does that function keep the accumulator as well as the first n elements of the stream in memory at once? With the continue/halt mechanism of `reduce_while` it would be easy to incorporate different early stopping criteria.
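For context, a rough sketch of what the comprehension above could look like as `Enum.reduce_while/3`. The stream is consumed lazily, so only the accumulator and the current batch are held at a time; the placeholder `train_step` and the halting condition are illustrative, not the PR's actual implementation:

```elixir
# Placeholder training step: returns updated state and a batch loss.
train_step = fn state, _inp, _tar -> {state, Nx.tensor(0.0)} end

{model_state, avg_loss, total_batches} =
  Enum.reduce_while(dataset, {model_state, Nx.tensor(0.0), 0}, fn
    {{inp, tar}, _i}, {state, loss, n} ->
      {new_state, batch_loss} = train_step.(state, inp, tar)
      acc = {new_state, Nx.add(loss, batch_loss), n + 1}

      # Halt once the batch loss drops below an (arbitrary) threshold,
      # otherwise keep folding over the stream.
      if Nx.to_number(batch_loss) < 1.0e-4 do
        {:halt, acc}
      else
        {:cont, acc}
      end
  end)
```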
@seanmor5 @josevalim, do you want me to add options for early stopping in this PR? Otherwise it'll be a straightforward addition to convert the comprehension to `reduce_while` later on.
Early stopping might be specified by a function that accepts a combination of current batch/step loss, previous batch loss, average loss, and batch index and returns a boolean indicating whether training should continue.
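To make that concrete, one possible (hypothetical) shape for such a predicate, assuming the losses are passed as plain numbers (or converted with `Nx.to_number/1`); the thresholds are arbitrary placeholders:

```elixir
# Given the current batch loss, the previous batch loss, the running
# average loss, and the batch index, return true to keep training.
continue_training? = fn batch_loss, prev_loss, avg_loss, batch_index ->
  batch_index < 10_000 and
    batch_loss > 1.0e-4 and
    abs(prev_loss - avg_loss) > 1.0e-6
end

continue_training?.(0.37, 0.41, 0.39, 128)
#=> true
```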
@t-rutten I would hold off on that for now; I would like to implement that as a callback when those are added into the training API.
Makes sense @seanmor5. Do you have any more suggestions for the changes here?
No changes, I'm happy with where this is. Unless you have anything else you'd like to add, I'll merge :)
I don't have anything else to add :) Thanks!
Ahh, I just realized the formatting failure is probably from my recent commit. In any case, feel free to run the formatter on everything so it passes before we merge :)
Per discussion in #25, changes here adapt the training API to accept `inputs` and `labels` that are streams.
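As a rough illustration of what such stream inputs could look like (shapes and generator functions are placeholders, and the call into the training API itself is omitted):

```elixir
# Lazy inputs/labels streams; batches are only realized as training
# consumes them.
inputs = Stream.repeatedly(fn -> Nx.iota({32, 784}, type: {:f, 32}) end)
labels = Stream.repeatedly(fn -> Nx.iota({32, 10}, type: {:f, 32}) end)

# Zipped into {input, target} pairs, matching the shape the training
# loop pattern matches on.
dataset = Stream.zip(inputs, labels)
```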