Multi-Scale Retention: Why include position embeddings explicitly? #48

Closed
fkodom opened this issue Aug 2, 2023 · 3 comments

fkodom commented Aug 2, 2023

My question is about the RetNet paper, on which the implementation here is based.

Why include the positional embedding updates directly in the multi-scale retention layer, rather than just applying them to the RetNet inputs?

[Screenshots of the multi-scale retention equations from the RetNet paper]

IMO, this seems overly specific to the language modeling use case. Other applications of retention/attention should be free to use whatever positional embeddings they need/want.

The retention formulation is still self-consistent (i.e. equivalent for parallel, recurrent, chunkwise) without explicitly including positional embeddings in the retention layer. See Equations (1) and (2):

[Screenshot of Equations (1) and (2) from the RetNet paper]

Instead of forcing positional embeddings into the retention formulation, we can just set A equal to the decay matrix D. The parallel/recurrent/chunkwise formulations are still equivalent, and we remove the hard-coded dependence on xPos embeddings in the retention layer.
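As a concrete check, here is a minimal sketch (not this repo's code; the tensor shapes and decay value are illustrative) showing that the parallel and recurrent forms still agree when A is just the scalar decay γ, i.e. when the xPos rotation is dropped and only the decay matrix D remains:

```python
import torch

torch.manual_seed(0)
B, T, d = 2, 8, 16          # batch, sequence length, head dim (illustrative)
gamma = 0.9                 # scalar decay; no xPos rotation anywhere below

q = torch.randn(B, T, d)
k = torch.randn(B, T, d)
v = torch.randn(B, T, d)

# Parallel form: (Q K^T ⊙ D) V, with D[n, m] = gamma^(n-m) for n >= m, else 0.
idx = torch.arange(T)
D = (gamma ** (idx[:, None] - idx[None, :])) * (idx[:, None] >= idx[None, :])
parallel_out = (q @ k.transpose(-1, -2) * D) @ v

# Recurrent form: s_n = gamma * s_{n-1} + k_n^T v_n,  o_n = q_n s_n.
s = torch.zeros(B, d, d)
recurrent_out = []
for t in range(T):
    s = gamma * s + k[:, t, :, None] * v[:, t, None, :]   # outer-product update
    recurrent_out.append(torch.einsum("bi,bij->bj", q[:, t], s))
recurrent_out = torch.stack(recurrent_out, dim=1)

print(torch.allclose(parallel_out, recurrent_out, atol=1e-4))  # True
```

The chunkwise form follows the same pattern (parallel within each chunk, recurrent across chunks), so it stays equivalent as well.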

Conceptually, I'm thinking about how to apply RetNet to other data domains (images, heterogeneous graphs, etc.). In those cases, xPos embeddings don't reflect the actual structure of the data (2D position in an image, generic position within a graph, etc.). Does it make sense to remove the explicit position embedding from the retention layer, or am I missing something?

sunyt32 (Contributor) commented Aug 3, 2023

$e^{i\theta}$ works well for language modeling, so we set it as the default. We haven't evaluated other domains yet, and I agree that the rotation may not be the best option there. An optimization trick may also be needed: naively making the angles learnable parameters causes NaNs in the gradients. You can try adjusting them manually, or explore a workable way to optimize them.
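For reference, a hypothetical sketch of what the naive "learnable angles" parameterization could look like (module name and initialization are made up for illustration; this is the variant the comment above warns is unstable, not a recommended recipe):

```python
import torch
import torch.nn as nn

class NaiveLearnableRotation(nn.Module):
    """Rotation angles as free parameters (hypothetical illustration only)."""

    def __init__(self, head_dim: int):
        super().__init__()
        # Initialize like a fixed RoPE/xPos frequency schedule, then let it train.
        freqs = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.theta = nn.Parameter(freqs)            # [head_dim // 2]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, head_dim]; rotate each (even, odd) channel pair by n * theta.
        n = torch.arange(x.shape[1], device=x.device, dtype=x.dtype)[:, None]
        angles = n * self.theta                     # [seq_len, head_dim // 2]
        cos, sin = angles.cos(), angles.sin()
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```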

donglixp (Contributor) commented Aug 3, 2023

It depends on how you interpret "position embeddings". For example, we can also add position embeddings (such as a "generic position within a graph") to the token embeddings, treating the positions as attributes.
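A minimal sketch of that idea (module and parameter names are illustrative, not from this repo): embed the positions as attributes, add them to the token embeddings, and feed the result to the RetNet blocks, leaving the retention layers free of any built-in positional term.

```python
import torch
import torch.nn as nn

class EmbeddingWithPositionAttributes(nn.Module):
    """Add attribute-style position embeddings to token embeddings (illustrative)."""

    def __init__(self, vocab_size: int, num_positions: int, d_model: int):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(num_positions, d_model)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        # tokens, positions: [batch, seq_len]; `positions` can be any discrete
        # attribute (graph node index, flattened 2D patch coordinate, ...).
        return self.token_embed(tokens) + self.pos_embed(positions)
```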

fkodom (Author) commented Aug 8, 2023

Thanks! This is exactly what I was looking for. 😎
