
Custom position offset when rotating queries or keys #2

Closed

Conversation

@krasserm krasserm commented Dec 2, 2022

This library seems to assume that queries and keys are left-aligned position-wise, e.g.

q = [p_0, p_1, p_2]
k = [p_0, p_1, p_2, p_3, p_4]

where p_i are the corresponding positions. This is enforced by always starting the sequence of positions from 0 with torch.arange(seq_len). Applications like Perceiver AR, however, require a position-wise right-alignment, e.g.

q =           [p_2, p_3, p_4]
k = [p_0, p_1, p_2, p_3, p_4]
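
For concreteness, here is a minimal sketch (plain torch, independent of this library's API) of the two position-index sequences; the offset k_len - q_len is the same one used in the examples below:

import torch

q_len, k_len = 3, 5

# default (left-aligned): query positions always start at 0
q_pos_left = torch.arange(q_len)                          # tensor([0, 1, 2])

# right-aligned: query positions start at k_len - q_len
start_pos = k_len - q_len
q_pos_right = torch.arange(start_pos, start_pos + q_len)  # tensor([2, 3, 4])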

This pull request allows specifying a start position for queries and/or keys to enable alignments other than left-alignment. For example

import torch
from rotary_embedding_torch.rotary_embedding_torch import RotaryEmbedding

rot = RotaryEmbedding(dim=32)

# (batch, heads, seq, dim); queries are shorter than keys
q = torch.ones(1, 8, 4, 32)
k = torch.ones(1, 8, 6, 32)

# normalize to unit vectors so the attention scores below reflect position only
q = q / torch.norm(q, dim=-1, keepdim=True)
k = k / torch.norm(k, dim=-1, keepdim=True)

# right-align the queries: start their positions at k_len - q_len
q_rot = rot.rotate_queries_or_keys(q, start_pos=k.shape[2] - q.shape[2])
k_rot = rot.rotate_queries_or_keys(k)

attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
print(attn[0, 0])

prints the following attention scores, which depend only on the relative positions of queries and keys

tensor([[0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
        [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581],
        [0.7288, 0.7670, 0.8581, 0.9571, 1.0000, 0.9571],
        [0.7361, 0.7288, 0.7670, 0.8581, 0.9571, 1.0000]])

(diagonal of 1s right-aligned) whereas the default behavior

...

q_rot = rot.rotate_queries_or_keys(q)
k_rot = rot.rotate_queries_or_keys(k)

attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
print(attn[0, 0])

would print

tensor([[1.0000, 0.9571, 0.8581, 0.7670, 0.7288, 0.7361],
        [0.9571, 1.0000, 0.9571, 0.8581, 0.7670, 0.7288],
        [0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
        [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581]])

(diagonal of 1s left-aligned).
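
As a sanity check, the two score matrices encode the same relative-position function, just shifted by start_pos = 2 columns; assuming the two attention tensors from above are kept in hypothetical variables attn_right and attn_left, the following holds:

# attn_right / attn_left are hypothetical names for the two tensors above
start_pos = k.shape[2] - q.shape[2]
assert torch.allclose(attn_right[..., start_pos:],
                      attn_left[..., :-start_pos], atol=1e-4)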

@krasserm krasserm (Author) commented Dec 6, 2022

Never mind, I'll use the lower-level API instead, e.g.

...

from rotary_embedding_torch.rotary_embedding_torch import apply_rotary_emb

q_len = q.shape[2]
k_len = k.shape[2]
start_pos = k_len - q_len

# frequencies for explicit position ranges: right-aligned positions
# for the queries, default positions starting at 0 for the keys
freq_q = rot(torch.arange(start_pos, start_pos + q_len))
freq_k = rot(torch.arange(k_len))

q_rot = apply_rotary_emb(freq_q, q)
k_rot = apply_rotary_emb(freq_k, k)

...

It just would have been more convenient to have a start_pos argument on rotate_queries_or_keys().
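
For reference, a minimal sketch of such a wrapper, built only on the library's RotaryEmbedding forward and apply_rotary_emb (the helper name rotate_with_offset is hypothetical, not part of the library):

import torch
from rotary_embedding_torch.rotary_embedding_torch import apply_rotary_emb

def rotate_with_offset(rot, t, start_pos=0):
    # rotate t as if its first token sat at position start_pos (seq dim = -2)
    seq_len = t.shape[-2]
    positions = torch.arange(start_pos, start_pos + seq_len, device=t.device)
    return apply_rotary_emb(rot(positions), t)

# usage: right-align the shorter queries against the keys
q_rot = rotate_with_offset(rot, q, start_pos=k.shape[2] - q.shape[2])
k_rot = rotate_with_offset(rot, k)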

@krasserm krasserm closed this Dec 6, 2022