Training API uses streams #60

Merged: 3 commits merged into elixir-nx:main on May 3, 2021
Conversation

t-rutten (Contributor):
Per discussion in #25, changes here adapt the training API to accept inputs and labels that are streams.
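A minimal sketch of what this enables (plain Elixir, with tensor construction elided; the tuple shapes are illustrative assumptions, not the PR's exact types): the dataset can be a lazy `Stream` of `{input, label}` pairs, so nothing is materialized until the training loop consumes it.

```elixir
# Sketch of a streamed dataset: each element is an {input_batch, label_batch}
# pair, generated on demand rather than held in memory up front.
inputs = Stream.map(1..1_000_000, fn i -> {:input_batch, i} end)
labels = Stream.map(1..1_000_000, fn i -> {:label_batch, i} end)
dataset = Stream.zip(inputs, labels)

# Only the elements actually consumed are ever produced:
Enum.take(dataset, 2)
# => [{{:input_batch, 1}, {:label_batch, 1}},
#     {{:input_batch, 2}, {:label_batch, 2}}]
```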

@seanmor5 (Contributor) left a review:
Thanks for putting this together! I left a comment to see if we can address that before we merge. BTW, make sure to run the formatter so it passes the CI!

```elixir
x when is_integer(x) ->
  x

{model_state, avg_loss, total_batches} =
  for {{inp, tar}, i} <- dataset, reduce: {model_state, Nx.tensor(0.0), 0} do
```
@seanmor5 (Contributor):
I'm guessing this runs to the end of the stream? I think we'll want an option to terminate before the end of a stream. I'm trying to think of the best way to do that without loading n steps into memory all at once with something like Enum.take. Any thoughts?

Contributor:
You can do a `Stream.take(n)` or a `throw` with `try`/`catch`.
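For reference, `Stream.take/2` composes lazily, so it bounds the number of batches without materializing them up front. A small sketch (the counter side effect is just there to show how many elements the source actually produces):

```elixir
# Stream.take/2 returns a new lazy stream; the counter incremented inside
# Stream.map shows that only the first n elements are ever generated,
# even though the source stream is infinite.
counter = :counters.new(1, [])

stream =
  Stream.repeatedly(fn -> :batch end)
  |> Stream.map(fn b ->
    :counters.add(counter, 1, 1)
    b
  end)
  |> Stream.take(3)

Enum.to_list(stream)
:counters.get(counter, 1)
# => 3 (only three batches were produced from the infinite source)
```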

@t-rutten (Contributor, Author) commented on Apr 29, 2021:
Can we go with Enum.reduce_while/3? We only need the accumulator when processing each batch, but does that function keep the accumulator as well as the first n elements of the stream in memory at once? With the continue/halt function of reduce_while it would be easy to incorporate different early-stopping criteria.

@t-rutten (Contributor, Author):
@seanmor5 @josevalim, do you want me to add options for early stopping in this PR? Otherwise it'll be a straightforward addition to convert the comprehension to reduce_while later on.

Early stopping might be specified by a function that accepts a combination of current batch/step loss, previous batch loss, average loss, and batch index and returns a boolean indicating whether training should continue.
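The shape described above could look something like this sketch (the `train_batch` function and the 0.05 loss threshold are made-up stand-ins, not the PR's code; `Enum.reduce_while/3` keeps only the accumulator, and halting stops consumption of the stream):

```elixir
# Early stopping with Enum.reduce_while/3 over an infinite stream.
# `train_batch` is a hypothetical stand-in that pretends the loss
# halves on every batch; real code would run a training step here.
train_batch = fn {_inp, _tar}, state ->
  %{state | loss: state.loss / 2, batches: state.batches + 1}
end

dataset = Stream.repeatedly(fn -> {:inp, :tar} end)

result =
  Enum.reduce_while(dataset, %{loss: 1.0, batches: 0}, fn batch, state ->
    state = train_batch.(batch, state)

    if state.loss < 0.05 do
      # Stop early; the rest of the stream is never consumed.
      {:halt, state}
    else
      {:cont, state}
    end
  end)

result.batches
# => 5 (loss reaches 1/32 = 0.03125 < 0.05 after five batches)
```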

@seanmor5 (Contributor):
@t-rutten I would hold off on that for now; I'd like to implement it as a callback once callbacks are added to the training API.

@t-rutten (Contributor, Author):
Makes sense @seanmor5. Do you have any more suggestions for the changes here?

@seanmor5 (Contributor):
No changes; I'm happy with where this is. Unless you have anything else you'd like to add, I'll merge :)

@t-rutten (Contributor, Author):
I don't have anything else to add :) Thanks!

@seanmor5 (Contributor):

Ahh, I just realized the formatting failure is probably from my recent commit. In any case, feel free to run the formatter on everything so it passes before we merge :)

@seanmor5 merged commit 1371b36 into elixir-nx:main on May 3, 2021.