scaled dot product attention #2500
Conversation
Actually, the code looks good. I added a couple of comments.
(Five review threads on src/mlpack/methods/ann/layer/scaled_dot_product_attention_impl.hpp, all outdated and resolved.)
Co-authored-by: Mikhail Lozhnikov <lozhnikovma@gmail.com>
…tripathi/mlpack into scaled_dot_product_attention
@lozhnikov The PyTorch implementation and test of scaled dot product attention are here. The results match. Can you take a look?
Some minor comments.
```cpp
key = const_cast<arma::Mat<eT>&>(input);
value = const_cast<arma::Mat<eT>&>(input);
```
Why do you need `const_cast` here? Why doesn't the following code work?
```diff
- key = const_cast<arma::Mat<eT>&>(input);
- value = const_cast<arma::Mat<eT>&>(input);
+ key = input;
+ value = input;
```
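For what it's worth, a minimal standalone sketch of why the plain assignment is enough: Armadillo's copy assignment accepts a const source and makes an owning copy, so no cast is needed as long as `key` and `value` are matrices rather than references (the names below are illustrative, not the layer's actual members).

```cpp
#include <armadillo>

int main()
{
  const arma::mat input(4, 3, arma::fill::randu);

  // Copy assignment from a const source is fine: Armadillo copies the
  // data, so key and value own their own memory and no const_cast is
  // needed (assuming they are matrices, not reference members).
  arma::mat key, value;
  key = input;
  value = input;

  // A const_cast would only be needed to bind a non-const reference to
  // input, which risks silently modifying the caller's data.
  key(0, 0) += 1.0;       // Changes only the copy.
  input.print("input:");  // Unchanged.
  return 0;
}
```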
(Two more review threads on src/mlpack/methods/ann/layer/scaled_dot_product_attention_impl.hpp, both outdated and resolved.)
```cpp
//! Test Backward function with mask.
module.Backward(input, gy, g);
expGrad = arma::mat("0.00000000 0.00000000;\
```
By the way, do you have a notebook for this? I'd like to play with the values a bit.
Yup. Here it is.
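For readers without the notebook, here is a rough standalone Armadillo sketch of the forward pass, softmax(Q K^T / sqrt(d_k)) V, handy for playing with the values; the function name, layout (tokens as rows), and mask convention are assumptions for illustration, not mlpack's layer API.

```cpp
#include <armadillo>
#include <cmath>

// Illustrative forward pass: attention(Q, K, V) = softmax(Q K^T / sqrt(dk)) V.
// Assumed mask convention: mask(i, j) == 0 hides key j from query i by
// pushing its score to a large negative value, so its softmax weight
// (and hence its gradient) is effectively zero.
arma::mat ScaledDotProductAttention(const arma::mat& query,
                                    const arma::mat& key,
                                    const arma::mat& value,
                                    const arma::mat& mask)
{
  const double dk = key.n_cols;  // Embedding dimension.
  arma::mat scores = query * key.t() / std::sqrt(dk);
  scores.elem(arma::find(mask == 0)).fill(-1e9);

  // Numerically stabilized row-wise softmax.
  scores.each_row([](arma::rowvec& r)
  {
    r = arma::exp(r - r.max());
    r /= arma::accu(r);
  });

  return scores * value;
}

int main()
{
  arma::mat x(5, 4, arma::fill::randu);  // Self-attention: Q = K = V = x.
  // Lower-triangular mask, as in causal self-attention.
  arma::mat mask = arma::trimatl(arma::ones<arma::mat>(5, 5));
  ScaledDotProductAttention(x, x, x, mask).print("output:");
  return 0;
}
```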
Looks like it's almost done. Let me look it over a couple more times.
Co-authored-by: Mikhail Lozhnikov <lozhnikovma@gmail.com>
…_dot_product_attention
@lozhnikov Let me know if I need to clarify anything about the new changes. Basically, I have tried to--
Any suggestions to make it better?
…_dot_product_attention
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍
Keep open
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍
Keep open
@lozhnikov The single-head attention is working. Probably we can now use the Concat layer to implement multi-head attention (#2375). 🚀 🚀 I am really delighted that this is working. :)
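As a rough sketch of that idea (omitting the learned per-head Q/K/V projections that a full multi-head layer would include), each head could run scaled dot-product attention on its own slice of the embedding and the results could be joined side by side, which is the role a Concat-style layer would play. `ScaledDotProductAttention` below refers to the illustrative helper sketched earlier in this thread, not an mlpack class.

```cpp
// Builds multi-head attention from the single-head sketch above by
// concatenating per-head outputs (the job the Concat layer would do).
// Learned per-head projections are omitted for brevity.
arma::mat MultiHeadAttention(const arma::mat& x, const size_t numHeads)
{
  const size_t headDim = x.n_cols / numHeads;  // Assumes an even split.
  arma::mat output;
  for (size_t h = 0; h < numHeads; ++h)
  {
    // Each head attends over its own contiguous slice of the embedding.
    const arma::mat slice = x.cols(h * headDim, (h + 1) * headDim - 1);
    const arma::mat head = ScaledDotProductAttention(slice, slice, slice,
        arma::ones<arma::mat>(x.n_rows, x.n_rows));  // No masking here.
    output = (h == 0) ? head : arma::join_rows(output, head);
  }
  return output;  // (tokens x embedDim) when numHeads divides embedDim.
}
```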