
Conversation

@bastings bastings commented Nov 18, 2020

This adds an optimized LSTM cell that is compatible with the regular LSTMCell.
It is faster because multiple smaller matrix multiplications are combined into larger ones.
It is compatible because the parameter matrices are still stored individually and combined dynamically before the computation, so LSTMCell and OptimizedLSTMCell can be exchanged without changing the computation. A test verifies this.

Worked on together with @adarob.

Context: https://twitter.com/avitaloliver/status/1328965366173851649?s=20
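The fusion described in this PR can be sketched in a few lines (a minimal numpy illustration of the idea, not the merged Flax code; all shapes and names here are hypothetical):

```python
import numpy as np

# Hypothetical illustration of fusing the four LSTM gate projections
# (i, f, g, o): each gate has its own kernel of shape (features, hidden).
rng = np.random.default_rng(0)
features, hidden = 8, 16
kernels = {g: rng.normal(size=(features, hidden)) for g in "ifgo"}
x = rng.normal(size=(4, features))  # a batch of 4 inputs

# Naive: four small matmuls, one per gate.
naive = {g: x @ k for g, k in kernels.items()}

# Optimized: concatenate the kernels once, do a single larger matmul,
# then split the result back into the four gate pre-activations.
big_kernel = np.concatenate([kernels[g] for g in "ifgo"], axis=1)
fused = x @ big_kernel
split = dict(zip("ifgo", np.split(fused, 4, axis=1)))

for g in "ifgo":
    assert np.allclose(naive[g], split[g])  # same results, one large matmul
```

Because the per-gate kernels are kept as separate parameters and only concatenated at call time, the fused cell stays checkpoint-compatible with the unfused one.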

@google-cla google-cla bot added the cla: yes label Nov 18, 2020
@bastings bastings requested a review from marcvanzee November 18, 2020 10:47
codecov-io commented Nov 18, 2020

Codecov Report

Merging #648 (e7a8d33) into master (3362ce1) will increase coverage by 0.41%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #648      +/-   ##
==========================================
+ Coverage   80.60%   81.02%   +0.41%     
==========================================
  Files          55       55              
  Lines        4254     4347      +93     
==========================================
+ Hits         3429     3522      +93     
  Misses        825      825              
Impacted Files            Coverage Δ
flax/linen/__init__.py    100.00% <100.00%> (ø)
flax/linen/recurrent.py   100.00% <100.00%> (ø)
flax/nn/recurrent.py      100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

avital commented Nov 18, 2020

Thanks @bastings! flax.nn is now deprecated, could you rewrite this on top of flax.linen? The Linen upgrade guide may be helpful.

@bastings
Contributor Author

> Thanks @bastings! flax.nn is now deprecated, could you rewrite this on top of flax.linen? The Linen upgrade guide may be helpful.

Oh shoot... call me old-fashioned. :-) Would it still be useful to have this in both nn AND linen, or just in linen?

avital commented Nov 18, 2020

If you already have the code on flax.nn and tests are passing, I guess there's no harm in adding it. But I wouldn't spend any additional work there.

@bastings
Contributor Author

Sounds good. Will add the linen version. 👍

avital commented Nov 26, 2020

Hi @bastings -- looking forward to merging once we have a Linen version!

@avital avital added the labels "Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)" and "Status: in progress" Nov 26, 2020
avital commented Nov 26, 2020

(Or feel free to mark this as "pull requests welcome", as perhaps someone else can help with this if you're too busy.)

@bastings
Contributor Author

I'll add it, just ran out of time that day :)

@bastings bastings marked this pull request as draft December 4, 2020 16:51
@bastings bastings marked this pull request as ready for review December 4, 2020 16:51
bastings commented Dec 4, 2020

@avital @marcvanzee done!

@marcvanzee marcvanzee left a comment


In general, it is better to create your own fork of Flax and make your changes in your own branch. Otherwise we end up with a huge number of branches in Flax.

return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


class DummyDense(Module):
@marcvanzee (Contributor) commented:

It seems we are using this special Dense layer because we want to do the lax.dot_general outside of it. I would rename it to something more descriptive, maybe DenseNoMatMul or DenseNoDotGeneral? One could even argue whether it is still a Dense, since it seems you just get a kernel and bias, so a name like KernelAndBias would also be fitting, I think.

We could also consider making this class private to OptimizedLSTMCell.

@bastings (Contributor Author) replied:

I tried to make it private to OptimizedLSTMCell but Flax doesn't like that (error), so I'll keep this outside.

I renamed it to DenseParams because that is what it is. Does that work for you? I still like DummyDense too, though, since it follows the Dense API but doesn't apply it.
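For readers skimming the thread, the DenseParams idea under discussion can be sketched roughly like this (a hypothetical plain-Python sketch, not the actual flax.linen module from the PR):

```python
import numpy as np

# Hypothetical sketch of DenseParams: a layer-like object that only
# *creates* a kernel and bias (mirroring the Dense API's parameter shapes)
# but never applies them, so the caller can fuse several kernels into a
# single matmul outside the module.
class DenseParams:
    def __init__(self, in_features, out_features, rng):
        self.kernel = rng.normal(size=(in_features, out_features))
        self.bias = np.zeros(out_features)

rng = np.random.default_rng(0)
# One parameter holder per LSTM gate; the kernels stay separate in the
# parameter tree but are concatenated before the single large matmul.
gates = {g: DenseParams(8, 16, rng) for g in "ifgo"}
big_kernel = np.concatenate([gates[g].kernel for g in "ifgo"], axis=1)
big_bias = np.concatenate([gates[g].bias for g in "ifgo"])
```

Keeping the per-gate parameters separate is what preserves checkpoint compatibility with the unfused LSTMCell.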

@bastings bastings requested a review from avital December 8, 2020 14:37
@bastings bastings requested a review from marcvanzee December 8, 2020 14:37
@bastings bastings removed the request for review from avital December 9, 2020 15:46
@copybara-service copybara-service bot merged commit 692a62c into master Dec 9, 2020
@copybara-service copybara-service bot deleted the optimized_lstm branch December 9, 2020 16:31