Possible bug in enc-dec attention? #90
Comments
@py4 hello and thanks for your interest in my library! Firstly, I tried running the reverse task in colab, and the loss does start to drop at around 700-800 iterations if you simply increase the number of heads on the encoder and decoder to 4 (I just randomly set it at that on the first try, so it could be lower). So it does work. Secondly, this repository actually combines the causal and enc-dec attention into one layer in the decoder. This was to simplify argument passing into the reversible net, although I did design a better argument-routing system in a sister repository that I may port over: https://github.com/lucidrains/sinkhorn-transformer/blob/master/sinkhorn_transformer/reversible.py#L8 Finally, if you do a bit more digging, you will find that in the official Trax repository, their encoder-decoder is actually not using LSH attention but full attention (just with reversible nets). To have a fair comparison, you can turn on full attention (with reversible nets as well) using the `enc_use_full_attn` / `dec_use_full_attn` flags.
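A minimal sketch of that suggestion, using the `ReformerEncDec` configuration quoted later in this thread with the head counts raised to 4; the `return_loss` training call is an assumption about the library's interface, so treat this as illustrative rather than the exact colab code:

```python
import torch
from reformer_pytorch import ReformerEncDec

# same toy configuration as below, but with 4 heads on both sides
model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads = 4,   # raised from 1 to 4
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads = 4    # raised from 1 to 4
).cuda()

src = torch.randint(0, 256, (1, 768)).cuda()
tgt = torch.randint(0, 256, (1, 768)).cuda()

loss = model(src, tgt, return_loss = True)  # assumed training-loss interface
loss.backward()
```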
@py4 also, if you have an older version of the library: unfortunately, there was a long-standing bug in the way contextual keys were handled, so please upgrade to the latest version!
@lucidrains I enabled full attention, but it still does not converge.
@py4 I'm back! I decided to investigate, and you were right: it was not converging even after 10k iterations. Naturally, this led me to run a ton of experiments, removing features left and right, trying to isolate the root cause. And I think I found it! The Reformer actually does not use classic attention with queries, keys, and values each in their own projected space. For LSH to work, the queries and keys must share the same space, which is called "shared-QK". It turns out the cause is the shared-QK space: once I rewrote it as QKV, it converged at around 1.2k steps. I then spent a lot of time trying to figure out how to make the shared-QK space work, and just before giving up, I decided to try switching to axial positional embeddings. Lo and behold, it starts working, and converges in about twice the iterations (2.4k). To replicate my findings, please upgrade to the newest version of the library (0.23) and rerun full attention in shared-QK space:

```python
model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads = 1,
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads = 1,
    enc_use_full_attn = True,
    dec_use_full_attn = True
).cuda()
```

As for not being as fast as Trax: could you look into their repository and see whether they are using shared QK or not? I suspect they are not, and their encoder / decoder is just the same as a vanilla transformer, but with reversibility.
Also, this is a great time to advertise my new work over at https://github.com/lucidrains/sinkhorn-transformer. Not only do I find it converges much faster than Reformer at reasonably large sequence lengths, but it does not need to be in shared-QK space and seems more flexible to build upon. I tried the Sinkhorn Transformer on your toy task, and it converged in 1.2k iterations as well.
@lucidrains Thank you very much for your fast response and efforts! I still think the problem comes from the enc-dec attention. There might be a couple of issues there. For instance, when you calculate q, k, v, I believe you use the same projection matrix for the encoder and decoder, but I think they should be separate. Are you still sure that concatenating the encoder and decoder representations and applying self-attention on top of the merged queries is a good idea?
@py4 that wasn't the issue, even though it is more faithful to the original transformer to have the decoder take turns focusing on itself and then on the context. To give you an idea: when I tried it as one layer vs. two in QKV space, having it as one layer only incurred a cost of 100 extra iterations on top of the 1.2k for two layers.
@py4 I'm guessing you will like my other repository :D It is more faithful to the original transformer, with the decoder as you described.
@py4 can you point me to which projection matrix you are talking about?
@py4 you should try the "enc-dec attention" as LSH for a fair comparison. The self-attention in the decoder doesn't really learn anything in your task.
Here, in the `x` you have merged both the encoder representations and the decoder representations. I'm not sure you can extract q, k, v from a shared `w_q` and `w_k` for both the encoder and decoder representations.
@py4 Yeah, and that's the main deficiency of the Reformer. You will notice that they use full QKV for the "enc-dec attention", but really the paper was written about LSH and having things in shared-QK space. I implemented the whole architecture as shared QK.
@py4 My full attention was written as shared-QK attention as well, to keep from having to add extra parameters.
Think about it: for long sequences, you run into the very problems the paper is purportedly trying to solve if you use full QKV enc-dec attention.
I think the concept of sharing Q and K is different from what I'm talking about. What I'm saying is that when you want to calculate qk and v from x, you multiply x by projection matrices w_qk and w_v, right? And yeah, I know that in the Trax implementation they have not used LSH for the enc-dec attention and are using full attention. But even in this full attention, they first calculate self-attention on the decoder (which can easily be replaced by LSH, as I did, and nothing bad happens) and then apply an enc-dec attention. What I'm saying is that maybe we should separate w_qk and w_v for the encoder and decoder when you apply self-attention on the mixture of encoder and decoder representations.
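A sketch of the two alternatives under discussion; the `w_qk_*` / `w_v_*` names are hypothetical, not the library's identifiers, and `enc` / `dec` stand for the encoder and decoder hidden states of shape (batch, seq, dim):

```python
import torch
from torch import nn

dim = 512

# current approach: concatenate, then project everything with one w_qk / w_v
w_qk_shared = nn.Linear(dim, dim, bias = False)
w_v_shared  = nn.Linear(dim, dim, bias = False)

def shared_projection(enc, dec):
    x = torch.cat((dec, enc), dim = 1)
    return w_qk_shared(x), w_v_shared(x)

# proposed alternative: give the encoder and decoder states their own
# projection matrices before merging them for attention
w_qk_enc = nn.Linear(dim, dim, bias = False)
w_v_enc  = nn.Linear(dim, dim, bias = False)
w_qk_dec = nn.Linear(dim, dim, bias = False)
w_v_dec  = nn.Linear(dim, dim, bias = False)

def separate_projections(enc, dec):
    qk = torch.cat((w_qk_dec(dec), w_qk_enc(enc)), dim = 1)
    v  = torch.cat((w_v_dec(dec), w_v_enc(enc)), dim = 1)
    return qk, v
```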
@py4 In a lot of projects with full attention, I have mixed self-attention and enc-dec attention together, still with great results. I think the gradients will allow the encoder to adjust the space of the contextual keys to align with the decoder. No worries, I will eventually build it the way you describe, since I was able to get argument routing working with reversible networks in my other project. How about I put this as a final todo item on the Projects tab? edit - https://github.com/lucidrains/reformer-pytorch/projects/2
If you run the encoder / decoder from my Sinkhorn project against the Reformer with full attention, you will see they take about the same number of iterations on the task you have, even though Sinkhorn keeps self-attention and contextual attention separate.
Ok. Thank you :)
In the encoder-decoder architecture, the encoder output is passed to the decoder as keys to be used in attention — here: reformer-pytorch/reformer_pytorch/reformer_pytorch.py, line 598 at commit 5f5bbf4. I don't know if this is the reason or not, but I have a simple copy-reverse task where the loss stops at 2.08, whereas in the Trax code the loss gets close to 0 after a few steps.
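For context, a rough sketch of the pattern the issue describes — an encoder's output handed to a causal decoder as contextual keys — assuming two `ReformerLM` instances and the `return_embeddings` / `keys` arguments; treat the exact argument names as assumptions about the library's interface at the time:

```python
import torch
from reformer_pytorch import ReformerLM

encoder = ReformerLM(
    num_tokens = 256,
    dim = 512,
    depth = 1,
    max_seq_len = 768,
    return_embeddings = True   # emit hidden states instead of logits (assumed flag)
)

decoder = ReformerLM(
    num_tokens = 256,
    dim = 512,
    depth = 1,
    max_seq_len = 768,
    causal = True
)

src = torch.randint(0, 256, (1, 768))
tgt = torch.randint(0, 256, (1, 768))

enc_keys = encoder(src)                  # (1, 768, 512) contextual keys
logits = decoder(tgt, keys = enc_keys)   # decoder attends over its own tokens plus the keys
```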