scaled dot product attention #2500

@mrityunjay-tripathi (Member) commented Jul 4, 2020

@lozhnikov Single-head attention is working. We can probably now use the Concat layer to implement multi-head attention (#2375).
🚀 🚀 I am really delighted that this is working. :)

@lozhnikov (Contributor) left a comment

Actually, the code looks good. I added a couple of comments.

@mrityunjay-tripathi (Member, Author) commented

@lozhnikov The PyTorch implementation and test of scaled dot product attention are here. The results match. Can you have a look at this?
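
For readers following the thread, here is a minimal sketch of the quantity being compared, softmax(Q K^T / sqrt(d)) V, in plain C++. It is not the code from this PR and not mlpack's API: the `Matrix` alias, the function name `ScaledDotProductAttention`, and the row-per-position layout are assumptions made only for illustration, and masking is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major matrix: one row per sequence position (an assumption for this
// sketch, not the layout used in the PR).
using Matrix = std::vector<std::vector<double>>;

// Returns softmax(Q * K^T / sqrt(d)) * V, where d is the query width.
// Assumes non-empty inputs, key rows as wide as query rows, and one value
// row per key row.
Matrix ScaledDotProductAttention(const Matrix& query,
                                 const Matrix& key,
                                 const Matrix& value)
{
  const std::size_t d = query[0].size();
  const double scale = 1.0 / std::sqrt(static_cast<double>(d));
  Matrix output(query.size(), std::vector<double>(value[0].size(), 0.0));

  for (std::size_t i = 0; i < query.size(); ++i)
  {
    // Scaled dot product of query i with every key.
    std::vector<double> scores(key.size(), 0.0);
    for (std::size_t j = 0; j < key.size(); ++j)
    {
      for (std::size_t k = 0; k < d; ++k)
        scores[j] += query[i][k] * key[j][k];
      scores[j] *= scale;
    }

    // Softmax over the keys; subtracting the maximum avoids overflow.
    const double maxScore = *std::max_element(scores.begin(), scores.end());
    double sum = 0.0;
    for (double& s : scores)
    {
      s = std::exp(s - maxScore);
      sum += s;
    }

    // Weighted sum of the value rows.
    for (std::size_t j = 0; j < key.size(); ++j)
      for (std::size_t c = 0; c < value[j].size(); ++c)
        output[i][c] += (scores[j] / sum) * value[j][c];
  }
  return output;
}
```

A quick sanity check on a sketch like this: when all keys are identical, the softmax weights are uniform, so each output row should be the mean of the value rows.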

@lozhnikov (Contributor) left a comment

Some minor comments.

Comment on lines 74 to 75
key = const_cast<arma::Mat<eT>&>(input);
value = const_cast<arma::Mat<eT>&>(input);
@lozhnikov (Contributor) commented

Why do you need const_cast here? Why doesn't the following code work?

Suggested change
- key = const_cast<arma::Mat<eT>&>(input);
- value = const_cast<arma::Mat<eT>&>(input);
+ key = input;
+ value = input;
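
On the question above: in standard C++, copy assignment normally takes its right-hand side by const reference, so assigning from a const input compiles and makes a copy without any cast. The following is a minimal sketch under stated assumptions: `Mat` is a `std::vector<double>` standing in for `arma::Mat<eT>`, and `AttentionInputs` and `SetFromInput` are hypothetical names. Whether the layer's own members can be assigned this way depends on how they are declared, which this thread does not show.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for arma::Mat<eT>, used only for illustration.
using Mat = std::vector<double>;

// Hypothetical fragment of a layer that keeps its own key and value.
struct AttentionInputs
{
  Mat key;
  Mat value;

  // Copy assignment accepts a const reference, so no const_cast is needed.
  // Each member receives an independent copy of the input.
  void SetFromInput(const Mat& input)
  {
    key = input;
    value = input;
  }
};
```

Because the members are copies, later changes to `key` or `value` leave the caller's `input` untouched, which is the behavior a const parameter promises.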

src/mlpack/tests/ann_layer_test.cpp

//! Test Backward function with mask.
module.Backward(input, gy, g);
expGrad = arma::mat("0.00000000 0.00000000;\
@lozhnikov (Contributor) commented

By the way, do you have a notebook for this? I'd like to play with the values a bit.

@mrityunjay-tripathi (Member, Author) replied

Yup. Here it is.

@lozhnikov (Contributor) commented

Looks like it's almost done. Let me look through it a couple more times.

mrityunjay-tripathi and others added 3 commits July 20, 2020 04:17
Co-authored-by: Mikhail Lozhnikov <lozhnikovma@gmail.com>
@mrityunjay-tripathi (Member, Author) commented Jul 20, 2020

@lozhnikov Let me know if I need to clarify anything about the new changes. Basically, I have tried to:

  1. Handle empty key and value in a better way.
  2. Get rid of code duplication.
  3. Get rid of using const_cast.

Any suggestions to make it better?

@mlpack-bot (bot) commented Aug 25, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label Aug 25, 2020
@lozhnikov (Contributor) commented

Keep open

@mlpack-bot mlpack-bot bot removed the s: stale label Aug 25, 2020
@mlpack-bot (bot) commented Sep 24, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label Sep 24, 2020
@lozhnikov (Contributor) commented

Keep open

@mlpack-bot mlpack-bot bot removed the s: stale label Sep 24, 2020
@mlpack-bot (bot) commented Oct 24, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label Oct 24, 2020
@mlpack-bot mlpack-bot bot closed this Oct 31, 2020
4 participants