Add optimized LSTM cell. #648
Conversation
Codecov Report
@@ Coverage Diff @@
## master #648 +/- ##
==========================================
+ Coverage 80.60% 81.02% +0.41%
==========================================
Files 55 55
Lines 4254 4347 +93
==========================================
+ Hits 3429 3522 +93
Misses 825 825
Continue to review full report at Codecov.
Thanks @bastings!

Oh shoot.. call me old-fashioned. :-) Would it still be useful to have this in both

If you already have the code on

Sounds good. Will add the

Hi @bastings -- looking forward to merging once we have a Linen version!

(Or feel free to mark this as "pull requests welcome", as perhaps someone else can help with this if you're too busy.)

I'll add it, just ran out of time that day :)

@avital @marcvanzee done!
marcvanzee
left a comment
In general, it is better to create your own fork of Flax and make your changes in a branch there. Otherwise we end up with a huge number of branches in the Flax repository.
flax/linen/recurrent.py (Outdated)

    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)

    class DummyDense(Module):
It seems we are using this special Dense layer because we want to do the lax.dot_general outside of it. I would rename it to something more descriptive, maybe DenseNoMatMul or DenseNoDotGeneral? One could even argue whether it is still a Dense, since it seems you just get a kernel and bias, so a name like KernelAndBias is also fitting, I think.
We could also consider making this class private to OptimizedLSTMCell.
I tried to make it private to OptimizedLSTMCell, but Flax doesn't like that (it raises an error), so I'll keep it outside.
I renamed it to DenseParams, because that is what it is. Does that work for you? I still like DummyDense too, though, since it follows the Dense API; it just doesn't apply the layer.
Co-authored-by: Marc van Zee <marcvanzee@google.com>
This adds an optimized LSTM cell that is compatible with the regular LSTMCell.
It is faster because multiple smaller matrix multiplications are combined into larger ones.
It is compatible because the parameter matrices are still saved individually and combined dynamically before the computation, so LSTMCell and OptimizedLSTMCell can be exchanged without changing the computation. A test verifies this.
Worked on this together with @adarob.
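The fusion described above can be sketched in plain JAX (a hedged illustration, not the PR's code): the per-gate kernels stay separate in the checkpoint, but are concatenated at call time so four matmuls become one larger one with identical results.

```python
import jax
import jax.numpy as jnp


def gates_naive(x, kernels):
    """One matmul per gate (i, f, g, o), as in the regular LSTMCell."""
    return [x @ k for k in kernels]


def gates_fused(x, kernels):
    """Concatenate the per-gate kernels and do one larger matmul.

    The parameters are still stored per gate, so checkpoints remain
    interchangeable with the unfused version; only the compute changes.
    """
    fused = jnp.concatenate(kernels, axis=-1)     # (features, 4 * hidden)
    out = x @ fused                               # single large matmul
    return jnp.split(out, len(kernels), axis=-1)  # back to per-gate blocks
```

On accelerators, one (features, 4*hidden) matmul typically utilizes the hardware better than four (features, hidden) matmuls, which is where the speedup comes from.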
Context: https://twitter.com/avitaloliver/status/1328965366173851649?s=20