How to get q,k,v? #6

Closed
liyiersan opened this issue Sep 18, 2021 · 2 comments
Comments

@liyiersan

liyiersan commented Sep 18, 2021

In your code, I do not understand how q, k, and v are obtained from x and x_pooled. I have been confused for several days by the roll operation on k_windows and the unfold operation on k_pooled_k. Take stage 1 as an example. For level 0, since sw is 1, I think sr should be window_size//sw, that is, 7. For level 1, since sw is 7, I think sr should be output_size//sw, that is, 8. The number of keys would therefore be 7*7+8*8 = 113. But in your paper, you set sr to 13 at level 0 and to 7 at level 1. Why 13 and 7? And in your code, the number of keys is 7*7+4*7*7-4*(7-3)*(7-3)+7*7 = 230, which differs from 7*7+13*13 = 218. As a suggestion, the window attention could be written more clearly, and more comments are needed in the code. Thanks a lot. Please forgive me if anything I said is wrong.
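
For reference, the three key counts mentioned above can be checked with a few lines of arithmetic. This is only a sketch: the names are hypothetical (not identifiers from the repository), and the reading of each term in the 230 expression as "center window + four masked rolled windows + pooled window" is an assumption.

```python
# Illustrative arithmetic only; all names below are hypothetical and are not
# identifiers from the repository.
WINDOW = 7        # local window side used throughout the post
SHIFT = 3         # shift implied by the (7-3) term in the post
PAPER_FINE = 13   # level-0 region side quoted from the paper
PAPER_COARSE = 7  # level-1 region side quoted from the paper

def expected_by_question():
    # 7x7 fine keys plus 8x8 pooled keys, as guessed in the question
    return WINDOW * WINDOW + 8 * 8

def stated_in_paper():
    # 7x7 plus 13x13, as quoted from the paper
    return PAPER_COARSE ** 2 + PAPER_FINE ** 2

def counted_in_code():
    # Assumed reading of 7*7 + 4*7*7 - 4*(7-3)*(7-3) + 7*7
    centre = WINDOW * WINDOW
    rolled = 4 * WINDOW * WINDOW - 4 * (WINDOW - SHIFT) ** 2
    pooled = WINDOW * WINDOW
    return centre + rolled + pooled

print(expected_by_question(), stated_in_paper(), counted_in_code())
```

Running it prints `113 218 230`, matching the numbers in the question.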

@liyiersan liyiersan changed the title How to get q,kv? How to get q,k,v? Sep 18, 2021
@jwyang
Member

jwyang commented Sep 22, 2021

Hi @liyiersan,

Good point! Let me first answer your question: "Why is it 230 in our implementation instead of 218?"

If we follow the illustration in our paper exactly, there are indeed 7x7 + 13x13 = 218 key and value tokens. However, this requires calling unfold with window size 13 for each local window, which we found to be much more time-consuming, especially for large feature maps. As such, we use an approximation: we roll the feature map four times with shift size 3, and then mask out the part of each rolled window that overlaps with the center 7x7 window. Even with this masking, there is still a slight overlap between the edges of adjacent rolled windows, amounting to exactly 4x3 = 12 tokens. That is why you see 230 = 218 + 12 tokens.
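
The counting argument above can be reproduced with a small coordinate simulation. This is a minimal sketch, not the repository's implementation: it assumes the four rolls are the four diagonal shifts (±3, ±3), and all names (`WINDOW`, `SHIFT`, `rolled_window_tokens`, `count_fine_tokens`) are hypothetical.

```python
from collections import Counter
from itertools import product

WINDOW = 7  # side of the local (center) window
SHIFT = 3   # roll shift size mentioned above

def rolled_window_tokens(dy, dx):
    """Positions kept from one window shifted by (dy, dx), after masking out
    every position that falls inside the center window."""
    centre = range(WINDOW)
    return [
        (r + dy, c + dx)
        for r, c in product(range(WINDOW), repeat=2)
        if not ((r + dy) in centre and (c + dx) in centre)
    ]

def count_fine_tokens():
    # Start with the unshifted center window.
    tokens = list(product(range(WINDOW), repeat=2))
    # Assumption: the four rolls are the four diagonal shifts.
    for dy, dx in product((-SHIFT, SHIFT), repeat=2):
        tokens += rolled_window_tokens(dy, dx)
    total = len(tokens)
    unique = len(Counter(tokens))
    return total, unique, total - unique

total, unique, duplicates = count_fine_tokens()
pooled = WINDOW * WINDOW  # 7x7 pooled tokens
print(total + pooled, unique + pooled, duplicates)
```

Under these assumptions, the fine-level tokens cover the full 13x13 region, and the only positions counted twice are the short strips where adjacent rolled windows meet, which gives the 12 extra tokens.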

To help you understand this, I uploaded some profiling code here.

Let me know if you have further questions.

Thanks,

@jwyang
Member

jwyang commented Sep 28, 2021

I guess the above answer addressed this issue, so I am going to close it.

@jwyang jwyang closed this as completed Sep 28, 2021
This issue was closed.