Possible bug in enc-dec attention? #90

Closed
py4 opened this issue Apr 29, 2020 · 18 comments

py4 commented Apr 29, 2020

In the encoder-decoder architecture, the encoder output is passed to the decoder as keys to be used in attention. Here (x = torch.cat((x, mem, keys), dim=1)) you are concatenating the keys with x (where x is the decoder input) and then applying self-attention. Does it make sense to do self-attention over the decoder input and the encoder outputs together? Even in the trax code these two are handled separately (https://github.com/google/trax/blob/c7c47a14ef8ea5b260ac78c22cbadd6dc1fb605b/trax/models/reformer/reformer.py#L968): first self-attention is applied to the decoder input, and then a separate encoder-decoder attention is applied between the new decoder representation and the keys.
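
To make the difference concrete, here is a rough sketch of the two patterns with plain torch.nn.MultiheadAttention (illustration only; neither library actually uses these modules):

import torch
import torch.nn as nn

dim, heads = 512, 4
dec = torch.randn(32, 768, dim)   # decoder input
enc = torch.randn(32, 768, dim)   # encoder output passed in as context ("keys")

# (a) combined: one self-attention over the concatenation of decoder and encoder states
combined_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
mixed = torch.cat((dec, enc), dim=1)
out_a, _ = combined_attn(mixed, mixed, mixed)
out_a = out_a[:, :dec.shape[1]]   # keep only the decoder positions

# (b) separate: self-attention on the decoder, then a distinct enc-dec (cross) attention
#     where queries come from the decoder and keys/values from the encoder output
self_attn  = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
h, _     = self_attn(dec, dec, dec)
out_b, _ = cross_attn(h, enc, enc)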

I don't know if this is the reason or not, but I have a simple copy-reverse task where the loss gets stuck at 2.08. With the trax code, however, the loss becomes close to 0 after a few steps.

import numpy as np
import torch
import tqdm
from reformer_pytorch import ReformerEncDec

# hyperparameters (values assumed; not shown in the original snippet)
LEARNING_RATE = 1e-4
NUM_BATCHES = 10000
GRADIENT_ACCUMULATE_EVERY = 4

def cycle():
    while True:
        source = torch.randint(2, 10, (32, 768)).long().cuda()
        target_np = np.flip(source.cpu().numpy(), axis=1).copy()  # reversed copy of the source sequence
        target = torch.from_numpy(target_np).long().cuda()

        mask = torch.ones(32, 768).bool().cuda()

        yield (source, target, mask)

# First example: Copy Reverse: 768 tokens - vocab size: 256

model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads=1,
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads=1,
).cuda()

#model = TrainingWrapper(model)
model.cuda()


# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

data = cycle()  # create the data generator once instead of rebuilding it every step

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        source, target, mask = next(data)
        loss = model(seq_in=source, seq_out=target, return_loss=True, enc_input_mask=mask)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
lucidrains (Owner) commented

@py4 hello and thanks for your interest in my library!

Firstly, I tried running the reverse task in Colab, and the loss does start to drop at around 700-800 iterations if you simply increase the number of heads on the encoder and decoder to 4 (I just set it to that on the first try, so it could be lower). So it does work.

Second, this repository actually combines the causal and enc-dec attention together into one layer in the decoder. This was done to simplify argument passing into the reversible net, although I did design a better argument-routing system in a sister repository that I may port over: https://github.com/lucidrains/sinkhorn-transformer/blob/master/sinkhorn_transformer/reversible.py#L8

Finally, if you do a bit more digging, you will find that in the official Trax repository their encoder-decoder actually does not use LSH attention but full attention (just with reversible nets). For a fair comparison, you can turn on full attention (with reversible nets as well) using the use_full_attn keyword, and I think you'll find that it converges much faster.

lucidrains (Owner) commented

@py4 Also, if you have an older version of the library, unfortunately there was a long-standing bug in the way contextual keys were handled, so please upgrade to the latest version!

py4 (Author) commented Apr 30, 2020

@lucidrains I enabled use_full_attn and also used 4 heads (although I believe 1 head should also work for this task), but it still takes too many steps for the loss to get close to 0 compared to the trax code. In the trax code (1 head, 1 encoder layer, 1 decoder layer, ...) it converges to 0 in less than 800 steps. Isn't this a sign of a bug or something?

lucidrains (Owner) commented

@py4 I'm back! I decided to investigate, and you were right, it was not converging even after 10k iterations. Naturally, this led me to run a ton of experiments, removing features left and right, trying to isolate the root cause. And I think I found it!

The Reformer actually does not use the classic attention where queries, keys, and values each get their own projection. For LSH to work, the queries and keys must share the same space; this is called "shared-QK". It turns out the cause is the shared-QK space: once I rewrote it as QKV attention, it converged at around 1.2k steps.
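
For anyone following along, a minimal sketch of the difference (illustrative only, not this repository's internals): in shared-QK attention a single projection produces both queries and keys, so they live in the same space and can be hashed into the same LSH buckets.

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 512
x = torch.randn(32, 768, dim)

# classic QKV attention: three independent projections
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
q, k, v = to_q(x), to_k(x), to_v(x)

# shared-QK attention (what LSH requires): one projection serves as both query and key
# (the real Reformer also normalizes the keys and masks a token's attention to itself,
#  omitted here for brevity)
to_qk = nn.Linear(dim, dim, bias=False)
qk = to_qk(x)
attn = F.softmax(qk @ qk.transpose(-1, -2) / dim ** 0.5, dim=-1)
out = attn @ to_v(x)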

I then spent a lot of time trying to figure out how to make the shared-QK space work, and just before giving up, I decided to try switching to axial positional embeddings. Lo and behold, it starts working, and converges in about twice the iterations (2.4k).
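
For what it's worth, axial positional embeddings factor the position table for the 768-token sequence into two much smaller tables that are broadcast and combined, rather than learning one embedding per position. A rough sketch of the idea (not the module this library actually uses):

import torch
import torch.nn as nn

dim, seq_len = 512, 768
axial_shape = (32, 24)            # 32 * 24 = 768 positions
axial_dims  = (256, 256)          # 256 + 256 = 512 = model dimension

# one small parameter table per axis, combined by broadcasting + concatenation
emb_rows = nn.Parameter(torch.randn(axial_shape[0], 1, axial_dims[0]))
emb_cols = nn.Parameter(torch.randn(1, axial_shape[1], axial_dims[1]))

pos = torch.cat((
    emb_rows.expand(axial_shape[0], axial_shape[1], axial_dims[0]),
    emb_cols.expand(axial_shape[0], axial_shape[1], axial_dims[1]),
), dim=-1).reshape(seq_len, dim)  # full (768, 512) positional table

tokens = torch.randn(1, seq_len, dim)
tokens = tokens + pos             # added to the token embeddings as usual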

To replicate my findings, please upgrade to the newest version of the library (0.23) and rerun with full shared-QK attention:

model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads=1,
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads=1,
    enc_use_full_attn = True,
    dec_use_full_attn = True
).cuda()

As for it not being as fast as Trax, could you look into their repository and see whether they are using shared QK or not? I suspect they are not, and that their encoder / decoder is just the same as a vanilla transformer, but with reversibility.

lucidrains (Owner) commented

Also, this is a great time to advertise my new work over at https://github.com/lucidrains/sinkhorn-transformer . Not only do I find it converges much faster than Reformer at reasonably large sequence lengths, but it does not need to be in shared-QK space, and seems more flexible to build upon. I tried out Sinkhorn Transformer on your toy task, and it converged in 1.2k as well.

py4 (Author) commented May 1, 2020

@lucidrains Thank you very much for your fast response and efforts!

I still think the problem is coming from the enc-dec attention. There might be a couple of issues there. For instance, when you calculate q, k, and v, I believe you use the same projection matrix for the encoder and the decoder, but I think they should be separate.
In the trax code, I changed the full attention in the encoder to LSH. In the decoder I changed the full self-attention to LSH but kept the enc-dec attention as full attention, and it still converges to 0 in less than 800 steps.
It is not intuitive at all that axial encoding should matter for this "reverse" task; in the trax code I am not using axial encoding for this toy task.

Are you still sure that concatenating the encoder and decoder representations and applying self-attention on top of the merged sequence is a good idea?

lucidrains (Owner) commented

@py4 that wasn't the issue, even though it is more faithful to the original transformer to have the decoder take turns focusing on itself and then on the context.

To give you an idea: when I tried it as one layer vs. two layers in QKV space, having it as one layer only cost about 100 extra iterations on top of the 1.2k for two layers.

lucidrains (Owner) commented

@py4 I'm guessing you will like my other repository :D It is more faithful to the original transformer, with the decoder as you described

lucidrains (Owner) commented

@py4 can you point me at which projection matrix you are talking about?

lucidrains (Owner) commented

@py4 you should try the "enc-dec attention" as LSH too, for a fair comparison. The self-attention in the decoder doesn't really learn anything in your task.

py4 (Author) commented May 1, 2020

Here (the x = torch.cat((x, mem, keys), dim=1) line), you have merged both the encoder representations and the decoder representation into x. I'm not sure you can extract q, k, v from a shared w_q and w_k for both the encoder and decoder representations.

lucidrains (Owner) commented

@py4 Yeah, and that's the main deficiency of Reformer. You will notice that they use full QKV attention for the "enc-dec attention", but really the paper was about LSH and having everything in shared-QK space. I implemented the whole architecture as shared-QK.

lucidrains (Owner) commented

@py4 My full attention was written as shared-QK attention as well, to avoid adding extra parameters.

lucidrains (Owner) commented May 1, 2020

Think about it: for long sequences, if you use full QKV enc-dec attention you run into exactly the problems the paper is purportedly trying to solve.
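
(As a back-of-envelope figure, not something from this thread: a full attention matrix over N tokens needs O(N^2) scores, so at N = 768 that is roughly 590k scores per head per example, while at the 64k-token lengths Reformer targets it is over 4 billion, about 17 GB in fp32 for a single head and example, which is exactly the cost LSH bucketing is meant to avoid.)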

py4 (Author) commented May 1, 2020

I think the concept of sharing Q and K is different from what I'm talking about. What I'm saying is that when you calculate qk and v from x, you multiply x by projection matrices w_qk and w_v, right?
Your x includes both the encoder output and the decoder input. It is not intuitive to me to use a single w for both the encoder and the decoder. In ordinary self-attention all the tensors are homogeneous: they are all encoder representations or all decoder representations. But here you have mixed them, and you are applying self-attention to a mixture of encoder and decoder representations.

And yes, I know that in the trax implementation they have not used LSH for the enc-dec attention and are using full attention. But even with this full attention, they first compute self-attention on the decoder (which can easily be replaced by LSH, as I did, and nothing bad happens) and then apply a separate enc-dec attention.

What I'm saying is that maybe we should use separate w_qk and w_v for the encoder and the decoder when applying self-attention to the mixture of encoder and decoder representations, as in the sketch below.
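
Something like this is what I have in mind (just a hypothetical sketch of the suggestion, not a proposed patch): keep the single mixed attention over the concatenation, but give the decoder tokens and the encoder context their own w_qk and w_v.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedSharedQKAttention(nn.Module):
    """Shared-QK attention over [decoder ; encoder], with separate projections per stream."""
    def __init__(self, dim):
        super().__init__()
        self.dec_qk = nn.Linear(dim, dim, bias=False)
        self.dec_v  = nn.Linear(dim, dim, bias=False)
        self.enc_qk = nn.Linear(dim, dim, bias=False)
        self.enc_v  = nn.Linear(dim, dim, bias=False)
        self.scale = dim ** -0.5

    def forward(self, dec, enc):
        # project each stream with its own weights, then concatenate along the sequence
        qk = torch.cat((self.dec_qk(dec), self.enc_qk(enc)), dim=1)
        v  = torch.cat((self.dec_v(dec),  self.enc_v(enc)),  dim=1)

        # shared-QK attention over the mixture; only decoder positions act as queries here
        q = qk[:, :dec.shape[1]]
        attn = F.softmax(q @ qk.transpose(-1, -2) * self.scale, dim=-1)
        return attn @ v                                   # (batch, dec_len, dim)

dec, enc = torch.randn(2, 768, 512), torch.randn(2, 768, 512)
out = MixedSharedQKAttention(512)(dec, enc)               # (2, 768, 512)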

lucidrains (Owner) commented May 1, 2020

@py4 In a lot of projects with full attention, I have mixed self-attention and enc-dec attention together and still gotten great results. I think the gradients will allow the encoder to adjust the space of the contextual keys to align with the decoder.

No worries, I will eventually build it the way you describe, since I was able to get argument routing working with reversible networks in my other project. How about I put this as a final todo item on the Projects tab?

edit - https://github.com/lucidrains/reformer-pytorch/projects/2

lucidrains (Owner) commented

If you run the encoder / decoder from my Sinkhorn project against the Reformer with full attention, you will see they take about the same number of iterations on your task, even though Sinkhorn keeps self-attention and contextual attention separate.

py4 (Author) commented May 1, 2020

Ok. Thank you :)
