
PMA implementation missing rFF? #11

Closed
Timsey opened this issue Jan 26, 2021 · 2 comments

Comments


Timsey commented Jan 26, 2021

Dear Juho,

First of all, thank you for the implementation! It has been very helpful to my understanding of the architecture.

I ran into an apparent discrepancy between the code and the paper, and I was wondering if you could help clear it up. In particular, it seems to me that the PMA implementation is missing the row-wise feed-forward layer mentioned in the paper:

PMA(S, Z) = MAB(S, rFF(Z))

The PMA code:

class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)

To me this reads PMA(S, X) = MAB(S, X), rather than the MAB(S, rFF(X)) of the paper.

Thanks!

Tim

@yoonholee (Collaborator)

Hi, you're right; the implementation there slightly diverges from what we described in the paper. In my experience, that change makes virtually no difference because the rFF at the end of the previous ISAB/SAB block serves the same role. Recovering the block in the paper should be a simple 2-line change: add a linear layer (nn.Linear(dim, dim)) and feed X through it before MAB.
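For concreteness, here is a sketch of what that 2-line change could look like. The `MAB` class below is a simplified stand-in built on `nn.MultiheadAttention` just to make the snippet self-contained; the real `MAB` in the repo differs in its internals (residual connections, optional layer norm, its own rFF). The two added lines are marked.

```python
import torch
import torch.nn as nn

class MAB(nn.Module):
    # Simplified stand-in for the repo's Multihead Attention Block,
    # here just cross-attention of Q over K, so the sketch runs on its own.
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.attn = nn.MultiheadAttention(dim_V, num_heads, batch_first=True)

    def forward(self, Q, K):
        Q, K = self.fc_q(Q), self.fc_k(K)
        out, _ = self.attn(Q, K, K)
        return out

class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.rFF = nn.Linear(dim, dim)  # added line 1: the row-wise feed-forward
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        # added line 2 (modified): PMA(S, X) = MAB(S, rFF(X)), as in the paper
        return self.mab(self.S.repeat(X.size(0), 1, 1), self.rFF(X))

# Usage: pool a set of 10 elements of dim 16 down to 1 seed vector per batch item
pma = PMA(dim=16, num_heads=4, num_seeds=1)
out = pma(torch.randn(2, 10, 16))  # (batch, set size, dim)
print(out.shape)  # torch.Size([2, 1, 16])
```

Since `nn.Linear` acts independently on each set element, applying it to `X` keeps the block permutation-equivariant in the set dimension, which is what "row-wise" means here.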

@Timsey
Author

Timsey commented Jan 28, 2021

Ah perfect: that makes sense, thanks!

@Timsey Timsey closed this as completed Jan 28, 2021