
LayerNorm/GatedRMS inconsistency #1

Open
inspirit opened this issue Apr 1, 2022 · 6 comments

inspirit commented Apr 1, 2022

Hi!
Looking through the pipeline, it seems there are some inconsistencies with the normalisation:

```
# ReLA
input to GRMSNorm
# attention code
output: Linear(inner_dim, dim) + GRMSNorm
# next, in the FF module
input to LayerNorm
```

Here we have a problem with double normalisation, since the last layer of the attention block is a GRMSNorm and the first layer of the FF block is a LayerNorm.

Looking at the paper, it seems that in ReLA the GRMSNorm is applied to the result of mult(attn, v) before the output projection, not after the projection as in this code.
I'm also confused about the use of LayerNorm in the FF module: should it be GRMSNorm instead? That isn't clear from the paper either.
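
For clarity, here is a rough single-head sketch of the ordering I read in the paper: normalise the aggregated values, then project. The class and argument names are just mine for illustration, not the repo's actual code.

```python
import torch.nn.functional as F
from torch import nn

class ReLAAttention(nn.Module):
    def __init__(self, dim, inner_dim, norm):
        super().__init__()
        self.scale = inner_dim ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.norm = norm                       # e.g. a gated RMSNorm over inner_dim
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        attn = F.relu(q @ k.transpose(-1, -2) * self.scale)  # ReLU in place of softmax
        out = attn @ v                         # aggregate the values
        out = self.norm(out)                   # normalise the aggregated values ...
        return self.to_out(out)                # ... then apply the output projection
```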

lucidrains (Owner) commented

@inspirit hello there! yea, i kind of did some improvisation there

i'm using the sandwich normalization formulation from another paper https://arxiv.org/abs/2105.13290 rather than just normalizing the aggregated values directly

for the feedforward, i'm not entirely sure; it probably wouldn't make that huge of a difference
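
for reference, a minimal sketch of what i mean by sandwich norm, just illustrative (made-up names and plain LayerNorms here, not the exact modules in this repo; in the attention block here the post norm is the GRMSNorm):

```python
from torch import nn

class SandwichNormBlock(nn.Module):
    def __init__(self, dim, sublayer):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)     # norm going into the sublayer
        self.sublayer = sublayer              # attention or feedforward
        self.post_norm = nn.LayerNorm(dim)    # norm on the sublayer output

    def forward(self, x):
        # both norms sit inside the residual branch
        return x + self.post_norm(self.sublayer(self.pre_norm(x)))
```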

inspirit commented Apr 6, 2022

Aha, I see, yup, I remember the sandwich norm paper :)
Another difference I noticed: you use projection-based gating (with a Linear layer) in GRMSNorm, while the original paper uses simple per-element multiplication, i.e. `return normed_x * (x * gate).sigmoid()`, where `gate` is a learned per-element parameter of shape `(dim,)`, e.g. `gate = nn.Parameter(torch.ones(dim))`.
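
In case it helps, a minimal sketch of the per-element gating I mean, assuming the usual RMSNorm formulation; the parameter initialisations are my guess, not necessarily the paper's exact values.

```python
import torch
from torch import nn

class GatedRMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))   # RMSNorm gain
        self.gate = nn.Parameter(torch.ones(dim))    # per-element gate (init is a guess)

    def forward(self, x):
        # root mean square over the feature dimension
        rms = x.norm(dim = -1, keepdim = True) * (x.shape[-1] ** -0.5)
        normed_x = x / rms.clamp(min = self.eps) * self.scale
        return normed_x * (x * self.gate).sigmoid()  # simple elementwise gating, no Linear
```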

lucidrains (Owner) commented

@inspirit ohh apologies, yea, i didn't build that correctly b58b121

let me know if that works! i've seen ReLU-based attention in another recent paper (https://github.com/lucidrains/FLASH-pytorch), so maybe there's something to it!

lucidrains (Owner) commented

@inspirit how did it go? :) any interesting experimental results?

inspirit commented Apr 8, 2022

It seems to be less stable compared to normal softmax attention. I fused it with a Perceiver for my experiments; sometimes it gives slightly better results, sometimes not :) The reason might be the small model inner dimension (128) and the sparser attention that results from the ReLU.

lucidrains (Owner) commented

@inspirit yea, i thought it would be too good to be true if relu attention worked 😞 it must have worked for FLASH because they confine their quadratic attention to local windows
