
Source for the gated GELU MLP #48

Open
breuderink opened this issue Sep 9, 2021 · 1 comment

Comments


breuderink commented Sep 9, 2021

Reading the code I found the following implementation for the feed-forward MLP of the Perceiver IO:

from torch import nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    def forward(self, x):
        # split the last dimension into values and gates
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),  # project to twice the hidden size
            GEGLU(),                         # gating halves it back to dim * mult
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)
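For intuition, the gating step can be sketched in plain Python (no torch): the feature vector is split in half, one half is passed through the exact GELU, and the two halves are multiplied elementwise, so the output has half as many features as the input.

```python
import math

def gelu(x):
    # exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(vec):
    # split the feature vector in half: values and gates
    half = len(vec) // 2
    x, gates = vec[:half], vec[half:]
    return [v * gelu(g) for v, g in zip(x, gates)]

out = geglu([1.0, 2.0, 3.0, 4.0])  # 2 outputs from 4 inputs
```

This is why the first linear layer projects to `dim * mult * 2`: after `GEGLU()` the width is back to `dim * mult`, matching the second linear layer.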

I could not find references to a gated GELU in the Perceiver IO paper nor in the code.

Is there a particular reason to use GEGLU instead of a plain GELU?

@lucidrains (Owner) commented

@breuderink ohh this is actually a trick from a Shazeer paper (https://arxiv.org/pdf/2002.05202.pdf) that should give an extra performance boost, but I should probably make it optional to stay faithful to the original paper
