Tensor shape mismatch when computing apply_rotary_pos_emb #2
I tried commenting out the last line and received the same error again:
After a thorough investigation of the source code, I discovered that within the attention implementation, the query and key are transformed into the following forms.
I printed the tensors' shapes:
apply_rotary_pos_emb requires multiplying the cosine tensor with the query, and their shapes clearly do not match. I'm uncertain about the original intent of the source code, so I'm unable to correct this issue on my own. |
In apply_rotary_pos_emb we use position_ids to slice the cos and sin tensors so that they are aligned with the queries and keys, so I think a shape-misalignment bug should be impossible here. Can you go into apply_rotary_pos_emb and print the shapes of the tensors inside? |
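For reference, the position_ids slicing described above matches (roughly) the Llama implementation in transformers 4.36; a simplified sketch:

```python
import torch

def rotate_half(x):
    # Rotate half of the last (head) dimension: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # cos/sin are cached as (max_seq_len, head_dim); indexing with
    # position_ids gathers the rows for the current tokens, and unsqueeze
    # lets them broadcast over the head dimension of q and k.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```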
I checked apply_rotary_pos_emb, but it looks a little different.
I've discovered that this is a compatibility issue. I have now rolled back to transformers==4.36 (it was 4.38), and that problem has disappeared, but now issue #1 has occurred. |
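For context, this is consistent with a change between those transformers versions: in 4.38 the rotary embedding's forward already gathers cos/sin per position, so apply_rotary_pos_emb no longer indexes with position_ids. Roughly:

```python
import torch

def rotate_half(x):  # same helper as in the sketch above
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # transformers 4.38 style (roughly): cos/sin already arrive gathered
    # per position as (batch, seq_len, head_dim), so position_ids is unused.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

Code written against one convention but run against the other indexes or broadcasts cos/sin with the wrong rank, which would produce exactly the kind of shape mismatch reported above.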
Oh, you need to install torch 2.1.2. Actually, only this torch version (and maybe 2.1.1) is compatible. I will deal with this later, but for now you can switch to torch 2.1.2. |
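(Putting the two fixes together: pinning both libraries, e.g. `pip install "transformers==4.36.*" "torch==2.1.2"`, should reproduce the working environment; the exact patch version of transformers is an assumption.)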
Thank you for your response. After reconfiguring the environment, it indeed runs smoothly now. |
Description:
When I tried to reproduce the paper's results following the README, an exception was raised:
I traced the function calls and enabled the 'debug' flag in engine.model_run. When I tried again, an assertion failed:
I checked the code and found a suspicious line in capture_graph: the last line changes static_attn_mask into a shape of (1, 1, x, y), which certainly fails the check.
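A minimal sketch of the pattern being described, assuming names and shapes from the report (the failing assertion itself isn't quoted, so the check below is hypothetical):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
max_len = 512  # placeholder; the real capture length comes from the engine

# static buffer allocated for CUDA-graph capture, initially 2-D: (x, y)
static_attn_mask = torch.zeros((max_len, max_len), dtype=torch.float16, device=device)

# the suspicious last line: adds batch and head dims -> (1, 1, x, y)
static_attn_mask = static_attn_mask[None, None, :, :]

# a downstream check still expecting the 2-D mask would then fail:
assert static_attn_mask.dim() == 2, "attention mask must be 2-D (x, y)"
```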